Spaces: Paused

lanny xu committed · Commit 399f3c6 · Parent(s): Initial commit
Browse files
- .env +2 -0
- COLAB_FILES_SUMMARY.md +305 -0
- COLAB_GPU_GUIDE.md +271 -0
- COLAB_OLLAMA_GUIDE.md +421 -0
- DEPLOYMENT_GUIDE.md +440 -0
- Dockerfile.gpu +58 -0
- GRAPHRAG_GUIDE.md +401 -0
- GRAPHRAG_INTEGRATION_SUMMARY.md +427 -0
- QUICKSTART.md +73 -0
- README.md +157 -0
- RERANKING_PRINCIPLES.md +359 -0
- __pycache__/config.cpython-310.pyc +0 -0
- __pycache__/config.cpython-313.pyc +0 -0
- __pycache__/document_processor.cpython-310.pyc +0 -0
- __pycache__/document_processor.cpython-313.pyc +0 -0
- __pycache__/entity_extractor.cpython-310.pyc +0 -0
- __pycache__/graph_indexer.cpython-310.pyc +0 -0
- __pycache__/graph_retriever.cpython-310.pyc +0 -0
- __pycache__/knowledge_graph.cpython-310.pyc +0 -0
- __pycache__/main.cpython-310.pyc +0 -0
- __pycache__/main.cpython-313.pyc +0 -0
- __pycache__/reranker.cpython-310.pyc +0 -0
- __pycache__/routers_and_graders.cpython-310.pyc +0 -0
- __pycache__/routers_and_graders.cpython-313.pyc +0 -0
- __pycache__/workflow_nodes.cpython-310.pyc +0 -0
- __pycache__/workflow_nodes.cpython-313.pyc +0 -0
- colab_gpu_demo.ipynb +588 -0
- colab_gpu_test.py +269 -0
- colab_quick_test.py +278 -0
- colab_setup_and_run.py +375 -0
- config.py +87 -0
- deploy_gpu.sh +240 -0
- docker-compose.gpu.yml +70 -0
- document_processor.py +195 -0
- entity_extractor.py +229 -0
- graph_indexer.py +145 -0
- graph_retriever.py +275 -0
- knowledge_graph.py +347 -0
- local_llm_rag.py +428 -0
- main.py +174 -0
- main_graphrag.py +293 -0
- requirements.txt +44 -0
- requirements_gpu.txt +56 -0
- requirements_graphrag.txt +31 -0
- reranker.py +350 -0
- routers_and_graders.py +147 -0
- test_reranking.py +218 -0
- workflow_nodes.py +240 -0
.env
ADDED
@@ -0,0 +1,2 @@
TAVILY_API_KEY="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"
# NOMIC_API_KEY="nk-kt4Tu3UdwFpIlDdxLcd9AK3a7cfdAKhoXvPbJ78oVlE"
COLAB_FILES_SUMMARY.md
ADDED
@@ -0,0 +1,305 @@
# 📦 Google Colab GPU Test Files Summary

## ✅ Files Created

| File | Type | Purpose | Rating |
|--------|------|------|--------|
| **colab_gpu_demo.ipynb** | Jupyter Notebook | Full interactive GPU test | ⭐⭐⭐⭐⭐ |
| **colab_quick_test.py** | Python script | One-click quick GPU test | ⭐⭐⭐⭐⭐ |
| **colab_gpu_test.py** | Python script | Modular GPU test tool | ⭐⭐⭐⭐ |
| **COLAB_GPU_GUIDE.md** | Documentation | Detailed usage guide | ⭐⭐⭐⭐⭐ |

---

## 🚀 Quick Start (3 Ways)

### Way 1: Interactive notebook test ⭐ recommended

**Best for**: first-time users who want to understand each step

```bash
# Step 1: upload the file
Upload colab_gpu_demo.ipynb to Google Colab

# Step 2: enable GPU
Runtime → Change runtime type → GPU

# Step 3: run
Runtime → Run all
```

**Advantages**:
- ✅ Visual output
- ✅ Step-by-step execution, easy to follow
- ✅ Supports live modification
- ✅ Clear Markdown explanations

---

### Way 2: Quick one-click test ⭐ fastest

**Best for**: quickly verifying GPU performance

```python
# In a new Colab notebook, run the following:

# 1. Enable GPU (Runtime → GPU)

# 2. Fetch and run
!wget https://your-repo/colab_quick_test.py
!python colab_quick_test.py

# Or paste the code directly into a cell and run it
```

**Advantages**:
- ✅ Zero configuration
- ✅ Installs dependencies automatically
- ✅ Full test suite in 5 minutes
- ✅ One complete report at the end

---

### Way 3: Modular test tool

**Best for**: developers who want deep customization

```python
# In Colab
!wget https://your-repo/colab_gpu_test.py
!python colab_gpu_test.py
```

**Advantages**:
- ✅ Clean code structure
- ✅ Easy to extend
- ✅ Can be integrated into other projects

---

## 📊 Test Coverage Comparison

| Test item | Notebook | Quick Test | GPU Test |
|---------|----------|------------|----------|
| GPU environment detection | ✅ | ✅ | ✅ |
| Matrix-multiplication benchmark | ✅ | ✅ | ✅ |
| Text-embedding benchmark | ✅ | ✅ | ✅ |
| GraphRAG components | ✅ | ❌ | ❌ |
| VRAM monitoring | ✅ | ✅ | ✅ |
| Performance report | ✅ | ✅ | ✅ |
| Interactive explanations | ✅ | ❌ | ❌ |
| nvidia-smi | ✅ | ✅ | ✅ |

---

## 🎯 Recommended Scenarios

### Scenario 1: First GPU test
**Recommended**: `colab_gpu_demo.ipynb`
- Detailed explanatory documentation
- Step-by-step execution, good for learning
- Good visualization

### Scenario 2: Quick performance check
**Recommended**: `colab_quick_test.py`
- One-click run
- Results in 5 minutes
- Complete performance report

### Scenario 3: CI/CD integration
**Recommended**: `colab_gpu_test.py`
- Modular design
- Easy to automate
- Returns standardized results

### Scenario 4: Learning GPU optimization
**Recommended**: `COLAB_GPU_GUIDE.md` + `colab_gpu_demo.ipynb`
- Theory plus practice
- Detailed performance analysis
- Optimization tips

---

## 📈 Expected Speedups

### Google Colab T4 GPU (free tier)

| Task | CPU | GPU | Speedup |
|------|-----|-----|--------|
| Matrix multiply (5000x5000) | 8 s | 0.3 s | **25x** |
| Text embedding (1000 texts) | 35 s | 6 s | **6x** |
| GraphRAG indexing (100 docs) | 15 min | 4 min | **3.8x** |

### Google Colab A100 GPU (Pro tier)

| Task | CPU | GPU | Speedup |
|------|-----|-----|--------|
| Matrix multiply | 8 s | 0.2 s | **40x** |
| Text embedding | 35 s | 3 s | **12x** |
| GraphRAG indexing | 15 min | 2.5 min | **6x** |

---

## 🔧 Full GraphRAG Deployment Flow

### Step 1: GPU performance test
```python
# Run the quick test to verify the GPU
!python colab_quick_test.py
```

### Step 2: Upload project files
```python
# Option A: from Google Drive
from google.colab import drive
drive.mount('/content/drive')
!cp -r /content/drive/MyDrive/adaptive_RAG /content/
%cd /content/adaptive_RAG

# Option B: from GitHub
!git clone YOUR_REPO_URL
%cd adaptive_RAG
```

### Step 3: Install dependencies
```python
!pip install -q -r requirements.txt
!pip install -q -r requirements_graphrag.txt
```

### Step 4: Configure API keys
```python
import os
from getpass import getpass
os.environ['TAVILY_API_KEY'] = getpass('TAVILY_API_KEY: ')
```

### Step 5: Run GraphRAG
```python
!python main_graphrag.py
```

### Step 6: Download the results
```python
from google.colab import files
files.download('data/knowledge_graph.json')
```

---

## 💡 Optimization Tips

### 1. Batch size
```python
# config.py
GRAPHRAG_BATCH_SIZE = 20  # can be raised in a GPU environment
```

### 2. Embedding model choice
```python
# use a larger model in a GPU environment
EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
```

### 3. Mixed precision
```python
import torch
torch.set_float32_matmul_precision('medium')
```

### 4. Data persistence
```python
# periodically save to Drive
import shutil
shutil.copy(
    'data/knowledge_graph.json',
    '/content/drive/MyDrive/backup.json'
)
```

---

## ⚠️ Caveats

### Colab free-tier limits
- ⏰ Continuous use: at most 12 hours
- 🔄 GPU quota: limited per week
- ⏸️ Idle timeout: 90 minutes

### Recommendations (see the backup sketch below)
- 💾 Save progress regularly
- ⬇️ Download results promptly
- 🔄 Keep the session active with a background task
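
A minimal sketch that combines the first and third recommendations: a background thread that periodically copies the knowledge-graph file to a mounted Drive folder. The paths and the 10-minute interval are illustrative assumptions, not part of the test scripts:

```python
import shutil
import threading
import time

def periodic_backup(src='data/knowledge_graph.json',
                    dst='/content/drive/MyDrive/backup.json',
                    interval=600):
    """Copy src to dst every `interval` seconds; also keeps the session busy."""
    while True:
        try:
            shutil.copy(src, dst)  # overwrites the previous backup
            print("Backup saved")
        except FileNotFoundError:
            pass  # graph not built yet
        time.sleep(interval)

# run in the background so the notebook stays usable
threading.Thread(target=periodic_backup, daemon=True).start()
```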

---

## 📚 File Usage Priority

### New users
1. 📖 Read `COLAB_GPU_GUIDE.md` first
2. 🚀 Run `colab_gpu_demo.ipynb`
3. ✅ Deploy the full project once performance is verified

### Advanced users
1. ⚡ Run `colab_quick_test.py` directly
2. 📊 Review the performance report
3. 🔧 Adjust the configuration as needed

### Developers
1. 🔍 Study the `colab_gpu_test.py` source
2. 🛠️ Customize the functionality as needed
3. 🔄 Integrate it into automated pipelines

---

## 🎯 Key Performance Baselines

### Baselines that must be met
- ✅ GPU detection: CUDA available
- ✅ Matrix speedup: >10x
- ✅ Embedding speedup: >5x
- ✅ VRAM usage: <80%

### If you fall below a baseline (see the diagnostic sketch below)
1. Check the GPU type (it should be a T4 or A100)
2. Restart the runtime
3. Check dependency versions
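
A minimal diagnostic sketch for those three checks, using only standard PyTorch calls (the version printout at the end is a convenience, not part of the original test scripts):

```python
import torch

# 1. check the GPU type (should report a T4 or A100)
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"GPU: {props.name}, VRAM: {props.total_memory / 1e9:.1f} GB")
else:
    print("CUDA unavailable - change the runtime type and restart")  # step 2

# 3. check dependency versions
print(f"torch {torch.__version__}, CUDA {torch.version.cuda}")
```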

---

## 📞 Getting Help

### Common questions
- See the FAQ section of `COLAB_GPU_GUIDE.md`

### Performance problems
- Run `colab_quick_test.py` to get a diagnostic report

### Technical support
- Include the test-report output
- Describe the exact error message

---

## ✅ Summary

| File | When to use it |
|------|---------|
| `colab_gpu_demo.ipynb` | First use, learning, demos |
| `colab_quick_test.py` | Quick verification, CI/CD, batch testing |
| `colab_gpu_test.py` | Deep customization, integration work |
| `COLAB_GPU_GUIDE.md` | Reference documentation, troubleshooting |

**Recommended flow**:
1. Read `COLAB_GPU_GUIDE.md` (5 minutes)
2. Run `colab_quick_test.py` (5 minutes)
3. If performance meets expectations, deploy the full GraphRAG project

**Expected outcome**:
- GPU available ✅
- 3-6x overall speedup ✅
- 10+ minutes saved ✅

---

🚀 **Get started now**: upload any of these files to [Google Colab](https://colab.research.google.com/) and enable GPU!
COLAB_GPU_GUIDE.md
ADDED
@@ -0,0 +1,271 @@
# 🚀 Google Colab GPU Test Guide

## 📋 Overview

I've created two files for testing GPU performance on Google Colab:

1. **`colab_gpu_demo.ipynb`** - Jupyter Notebook version (recommended)
2. **`colab_gpu_test.py`** - Python script version

## 🎯 Usage

### Method 1: The notebook (recommended)

#### Step 1: Upload to Colab

1. Open [Google Colab](https://colab.research.google.com/)
2. Click `File` → `Upload notebook`
3. Select `colab_gpu_demo.ipynb`

#### Step 2: Enable GPU

1. In the top menu, click `Runtime` → `Change runtime type`
2. Set the hardware accelerator to `GPU`
3. Choose `T4` (free tier) or `A100` (Colab Pro)
4. Click `Save`

#### Step 3: Run the tests

1. Click `Runtime` → `Run all`
2. Or run cells one at a time (Shift + Enter)

### Method 2: The Python script

#### Step 1: Upload the file

1. Create a new notebook in Colab
2. Click the folder icon on the left
3. Upload `colab_gpu_test.py`

#### Step 2: Run the script

```python
# in a Colab cell
!python colab_gpu_test.py
```

## 📊 What the Tests Cover

### 1. GPU environment detection ✅
- CUDA availability check
- GPU model and VRAM info
- nvidia-smi output

### 2. Matrix-multiplication benchmark ⚡ (see the sketch below)
- CPU vs GPU 5000x5000 matrix multiplication
- Expected speedup: **10-50x**
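
A minimal sketch of this kind of CPU-vs-GPU timing (not the exact code in `colab_gpu_test.py`; `torch.cuda.synchronize()` is needed because GPU kernels launch asynchronously):

```python
import time
import torch

n = 5000
a_cpu, b_cpu = torch.randn(n, n), torch.randn(n, n)

t0 = time.time()
_ = a_cpu @ b_cpu                   # CPU matrix multiply
cpu_s = time.time() - t0

a_gpu, b_gpu = a_cpu.cuda(), b_cpu.cuda()
torch.cuda.synchronize()            # wait for the transfer to finish
t0 = time.time()
_ = a_gpu @ b_gpu
torch.cuda.synchronize()            # wait for the kernel before stopping the clock
gpu_s = time.time() - t0

print(f"CPU {cpu_s:.2f}s  GPU {gpu_s:.2f}s  speedup {cpu_s / gpu_s:.1f}x")
```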

### 3. Text-embedding benchmark 📝 (see the sketch below)
- Uses sentence-transformers
- Embedding generation for 1000 texts
- CPU vs GPU comparison
- Expected speedup: **5-10x**
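
A sketch of the embedding comparison, assuming the `sentence-transformers` package and the MiniLM model used elsewhere in this project (the synthetic sentences are placeholders):

```python
import time
from sentence_transformers import SentenceTransformer

texts = [f"sample sentence number {i}" for i in range(1000)]  # synthetic workload

for device in ("cpu", "cuda"):
    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=device)
    t0 = time.time()
    model.encode(texts, batch_size=64, show_progress_bar=False)
    print(f"{device}: {time.time() - t0:.1f}s for {len(texts)} texts")
```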

### 4. GraphRAG component test 🔍
- Simplified knowledge-graph construction
- Entity and relation management
- GPU-accelerated vector retrieval

### 5. VRAM monitoring 💾 (see the sketch below)
- Real-time VRAM usage
- Memory-allocation statistics
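
The monitoring boils down to a few PyTorch calls; a minimal sketch:

```python
import torch

if torch.cuda.is_available():
    gb = 1024 ** 3
    total = torch.cuda.get_device_properties(0).total_memory / gb
    allocated = torch.cuda.memory_allocated(0) / gb  # tensors currently held
    reserved = torch.cuda.memory_reserved(0) / gb    # cached by the allocator
    print(f"VRAM: {allocated:.2f} GB allocated, "
          f"{reserved:.2f} GB reserved, {total:.1f} GB total")
```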

## 📈 Expected Results

### Google Colab free tier (T4 GPU)

| Test | CPU time | GPU time | Speedup |
|---------|---------|---------|--------|
| Matrix multiply (5000x5000) | ~8-10 s | ~0.3-0.5 s | 20-30x |
| Text embedding (1000 texts) | ~30-40 s | ~5-8 s | 5-7x |
| GraphRAG indexing (100 docs) | ~15 min | ~3-5 min | 3-5x |

### Google Colab Pro (A100 GPU)

| Test | CPU time | GPU time | Speedup |
|---------|---------|---------|--------|
| Matrix multiply | ~8 s | ~0.2 s | 40x |
| Text embedding | ~35 s | ~3 s | 10-12x |
| GraphRAG indexing | ~15 min | ~2-3 min | 5-7x |

## 🔧 Running the Full GraphRAG Project

If the GPU tests succeed, you can run the full GraphRAG project on Colab:

### Step 1: Upload project files

Create a new cell in Colab:

```python
# Option 1: load from Google Drive
from google.colab import drive
drive.mount('/content/drive')

# copy the project files
!cp -r /content/drive/MyDrive/adaptive_RAG /content/
%cd /content/adaptive_RAG
```

Or:

```python
# Option 2: clone from GitHub
!git clone YOUR_GITHUB_REPO_URL
%cd adaptive_RAG
```

### Step 2: Install dependencies

```python
# base dependencies
!pip install -q -r requirements.txt

# GraphRAG dependencies
!pip install -q -r requirements_graphrag.txt
```

### Step 3: Configure API keys

```python
import os
from getpass import getpass

# enter the API key without echoing it
os.environ['TAVILY_API_KEY'] = getpass('Enter TAVILY_API_KEY: ')

# confirm
print("✅ API key set")
```

### Step 4: Run GraphRAG

```python
# run the main program
!python main_graphrag.py
```

### Step 5: Download the results

```python
# download the constructed knowledge graph
from google.colab import files

files.download('data/knowledge_graph.json')

print("✅ Graph downloaded locally")
```

## 💡 Optimization Tips

### 1. Batch-size tuning

Adjust in `config.py`:

```python
# GPU-optimized setting
GRAPHRAG_BATCH_SIZE = 20  # the GPU can handle larger batches
```

### 2. Use a GPU-friendly model

```python
# use a larger embedding model (GPU environment)
EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
```

### 3. Relax float32 matmul precision

```python
# in entity_extractor.py
import torch
torch.set_float32_matmul_precision('medium')  # trades precision for speed
```

## ⚠️ Caveats

### Colab resource limits

1. **Free-tier limits**:
   - Continuous use: at most 12 hours
   - GPU quota: limited per week
   - Idle timeout: disconnects after 90 minutes

2. **Recommendations**:
   - Save progress to Google Drive regularly
   - Download important results with `files.download()`
   - Avoid long idle periods

### Data persistence

```python
# periodically save to Google Drive
from google.colab import drive
drive.mount('/content/drive')

# save the graph
import shutil
shutil.copy(
    'data/knowledge_graph.json',
    '/content/drive/MyDrive/graphrag_backup.json'
)
```

## 🐛 FAQ

### Q1: GPU fails to connect

**A**: Check the runtime type:
```python
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
# if False, set the runtime type again
```

### Q2: Out of memory

**A**: Reduce the batch size:
```python
GRAPHRAG_BATCH_SIZE = 5  # smaller batches
```

### Q3: Session timeouts

**A**: Use Colab Pro, or run code periodically to keep the session active:
```python
# run periodically in the background
import time
while True:
    print("Keep alive...")
    time.sleep(300)  # every 5 minutes
```

## 📚 References

- [Google Colab documentation](https://colab.research.google.com/notebooks/intro.ipynb)
- [GPU acceleration guide](https://colab.research.google.com/notebooks/gpu.ipynb)
- [Colab Pro pricing](https://colab.research.google.com/signup)

## 🎓 Next Steps

1. **Understand GPU acceleration**: study the performance comparisons in the test code
2. **Tune the GraphRAG parameters**: adjust the configuration to your GPU's capability
3. **Scale to production**: consider GPU instances on AWS/GCP

---

## ✅ Summary

| Advantage | Description |
|------|------|
| 🆓 Free GPU | T4 GPU at no cost |
| ⚡ High performance | 3-10x speedup |
| 🔄 Zero setup | no local installation |
| 💾 Auto save | Google Drive integration |
| 🌐 Access anywhere | only a browser needed |

**Bottom line**: when your local CPU environment is slow, a Colab GPU can dramatically speed up GraphRAG index construction!

---

**Get started now**: upload `colab_gpu_demo.ipynb` to [Google Colab](https://colab.research.google.com/) and enable GPU! 🚀
COLAB_OLLAMA_GUIDE.md
ADDED
@@ -0,0 +1,421 @@
# GraphRAG Colab Complete Run Guide

## 🎯 3 Ways to Run Ollama in Colab

### Method 1: Run Ollama in the background (recommended) ⭐⭐⭐⭐⭐

In Colab you can start Ollama in the background from one cell and run GraphRAG from another.

#### Step 1: Install Ollama

```bash
# Cell 1: install Ollama
!curl -fsSL https://ollama.com/install.sh | sh
```

#### Step 2: Start the Ollama service in the background

```python
# Cell 2: start Ollama in the background
import subprocess
import time
import os

# start the Ollama service (background)
ollama_process = subprocess.Popen(
    ["ollama", "serve"],
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
    preexec_fn=os.setpgrp
)

print("⏳ Waiting for the Ollama service to start...")
time.sleep(5)

# verify the service is up
!curl -s http://localhost:11434/api/tags | head -5

print(f"✅ Ollama service started (PID: {ollama_process.pid})")
```

#### Step 3: Download the Mistral model

```bash
# Cell 3: download the model
!ollama pull mistral
```

#### Step 4: Install the Python dependencies

```bash
# Cell 4: install dependencies
!pip install -q langchain langchain-community langchain-core langgraph
!pip install -q chromadb sentence-transformers tiktoken
!pip install -q tavily-python python-dotenv networkx python-louvain
```

#### Step 5: Configure the API key

```python
# Cell 5: configure the environment
import os
from getpass import getpass

os.environ['TAVILY_API_KEY'] = getpass('Enter TAVILY_API_KEY: ')
print("✅ API key set")
```

#### Step 6: Run GraphRAG

```python
# Cell 6: run GraphRAG
!python main_graphrag.py
```

#### Step 7: Download the results (optional)

```python
# Cell 7: download the generated graph
from google.colab import files
files.download('data/knowledge_graph.json')
```

---

### Method 2: Use tmux (advanced) ⭐⭐⭐⭐

```bash
# Cell 1: install tmux
!apt-get install -y tmux

# Cell 2: start Ollama inside a tmux session
!tmux new-session -d -s ollama 'ollama serve'

# Cell 3: check the session
!tmux ls

# Cell 4: download the model
!ollama pull mistral

# Cell 5: run GraphRAG
!python main_graphrag.py

# Cell 6: kill the tmux session (cleanup)
!tmux kill-session -t ollama
```

---

### Method 3: Use nohup (simple) ⭐⭐⭐

```bash
# Cell 1: start Ollama in the background
!nohup ollama serve > /tmp/ollama.log 2>&1 &

# Cell 2: wait for startup (this cell is Python)
import time
time.sleep(5)

# Cell 3: check the log
!tail -20 /tmp/ollama.log

# Cell 4: download the model
!ollama pull mistral

# Cell 5: run GraphRAG
!python main_graphrag.py

# Cell 6: stop Ollama (cleanup)
!pkill -f 'ollama serve'
```

---

## 🚀 One-Click Run Script (easiest) ⭐⭐⭐⭐⭐

I've created an automation script, `colab_setup_and_run.py`, that will:
1. ✅ Install Ollama automatically
2. ✅ Start the service in the background
3. ✅ Download the Mistral model
4. ✅ Install the Python dependencies
5. ✅ Configure the environment variables
6. ✅ Run GraphRAG

### Usage:

```bash
# Option A: run the script directly
!python colab_setup_and_run.py

# Option B: or from Python
import subprocess
subprocess.run(["python", "colab_setup_and_run.py"])
```

---

## 📊 Complete Colab Notebook Example

Create a new Colab notebook and run the following cells in order:

### Cell 1: Prepare the environment

```python
# detect the GPU
import torch
print(f"GPU available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU model: {torch.cuda.get_device_name(0)}")
```

### Cell 2: Install Ollama

```bash
%%bash
curl -fsSL https://ollama.com/install.sh | sh
echo "✅ Ollama installed"
```

### Cell 3: Start Ollama in the background

```python
import subprocess
import time
import os

print("🔄 Starting the Ollama service...")

# background start
process = subprocess.Popen(
    ["ollama", "serve"],
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
    preexec_fn=os.setpgrp
)

# wait for startup
time.sleep(5)

# verify
import requests
try:
    response = requests.get("http://localhost:11434/api/tags", timeout=3)
    if response.status_code == 200:
        print(f"✅ Ollama service running (PID: {process.pid})")
    else:
        print("⚠️ Unexpected response from the service")
except requests.RequestException:
    print("⚠️ Could not reach the service, but the process has started")

# save the process ID (important!)
ollama_pid = process.pid
print(f"📝 Saved PID: {ollama_pid}")
```

### Cell 4: Download the model

```bash
%%bash
echo "📥 Downloading the Mistral model..."
ollama pull mistral
echo "✅ Model downloaded"
ollama list
```

### Cell 5: Upload project files

```python
# Option A: from Google Drive
from google.colab import drive
drive.mount('/content/drive')

# copy the project files
!cp -r /content/drive/MyDrive/adaptive_RAG /content/
%cd /content/adaptive_RAG

# Option B: manual upload
# from google.colab import files
# uploaded = files.upload()
```

### Cell 6: Install dependencies

```bash
%%bash
pip install -q -r requirements.txt
pip install -q -r requirements_graphrag.txt
echo "✅ Dependencies installed"
```

### Cell 7: Configure the environment

```python
import os
from getpass import getpass

# set the API key
if not os.path.exists('.env'):
    api_key = getpass('Enter TAVILY_API_KEY: ')
    with open('.env', 'w') as f:
        f.write(f'TAVILY_API_KEY={api_key}\n')
    print("✅ .env file created")
else:
    print("✅ Using the existing .env file")
```

### Cell 8: Run GraphRAG

```python
# Option A: run directly
!python main_graphrag.py

# Option B: run from Python (captures the output)
import subprocess

result = subprocess.run(
    ["python", "main_graphrag.py"],
    capture_output=True,
    text=True
)

print(result.stdout)
if result.returncode != 0:
    print("Error output:")
    print(result.stderr)
```

### Cell 9: Download the results

```python
# download the generated knowledge graph
from google.colab import files

if os.path.exists('data/knowledge_graph.json'):
    files.download('data/knowledge_graph.json')
    print("✅ File downloaded")
else:
    print("❌ Graph file not found")

# back up to Google Drive
import shutil
shutil.copy(
    'data/knowledge_graph.json',
    '/content/drive/MyDrive/graphrag_backup.json'
)
print("✅ Backed up to Google Drive")
```

### Cell 10: Cleanup (optional)

```python
# stop the Ollama service
import os
import signal

try:
    os.kill(ollama_pid, signal.SIGTERM)
    print(f"✅ Ollama service stopped (PID: {ollama_pid})")
except OSError:
    print("⚠️ Failed to stop the service; stopping manually:")
    !pkill -f 'ollama serve'
```

---

## ⚠️ Common Problems

### Q1: The Ollama service exits immediately after starting

**A**: Use `subprocess.Popen` rather than a shell background job:

```python
# ❌ wrong: exits as soon as the cell finishes
!ollama serve &

# ✅ right
import subprocess
process = subprocess.Popen(["ollama", "serve"])
```

### Q2: Connection refused

**A**: Wait for the service to finish starting (a polling sketch follows):

```python
import time
time.sleep(10)  # increase the wait time
```
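
Rather than guessing at a fixed delay, a more robust sketch polls the Ollama HTTP endpoint until it answers; this assumes the default port 11434 and the `requests` package:

```python
import time
import requests

def wait_for_ollama(url="http://localhost:11434/api/tags", timeout=60):
    """Poll the Ollama API until it responds or `timeout` seconds pass."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            if requests.get(url, timeout=2).status_code == 200:
                return True
        except requests.RequestException:
            pass  # not up yet
        time.sleep(1)
    return False

print("ready" if wait_for_ollama() else "timed out")
```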

### Q3: Process management is awkward

**A**: Use a PID file:

```python
import os
import signal

# save the PID
with open('/tmp/ollama.pid', 'w') as f:
    f.write(str(process.pid))

# stop it later
with open('/tmp/ollama.pid', 'r') as f:
    pid = int(f.read())
os.kill(pid, signal.SIGTERM)
```

### Q4: Session timeout stops the service

**A**: Run code periodically to keep the session active:

```python
import time
while True:
    print("Keep alive...")
    time.sleep(300)  # every 5 minutes
```

---

## 📚 Recommended Overall Flow

1. ✅ **Run the automation script** - `!python colab_setup_and_run.py`
2. ✅ **Or follow the notebook example** - execute each cell step by step
3. ✅ **Save results regularly** - to Google Drive

---

## 💡 Best Practices

1. **Always save Ollama's PID**: it makes later management easy
2. **Use try-finally**: make sure the background process is cleaned up
3. **Back up regularly**: save intermediate results to Drive
4. **Watch VRAM usage**: avoid OOM errors

```python
# best-practice example
import subprocess
import atexit

# start Ollama
ollama_process = subprocess.Popen(["ollama", "serve"])

# register a cleanup function
def cleanup():
    try:
        ollama_process.terminate()
        print("✅ Ollama stopped")
    except OSError:
        pass

atexit.register(cleanup)

# run your code
try:
    # ... your GraphRAG code ...
    pass
finally:
    cleanup()
```

---

**Recommendation**: just use the `colab_setup_and_run.py` script; it already handles all of these details! 🚀
DEPLOYMENT_GUIDE.md
ADDED
@@ -0,0 +1,440 @@
# Linux GPU Deployment Guide (RTX 4090)

## 🚀 Deploying the Adaptive RAG System on Linux with an RTX 4090

This guide walks through deploying the adaptive RAG system on a Linux server equipped with an NVIDIA RTX 4090 GPU.

## 📋 Requirements

### Hardware
- NVIDIA RTX 4090 GPU
- At least 16 GB RAM (32 GB recommended)
- 50 GB+ free disk space
- Ubuntu 20.04+ / CentOS 8+ / RHEL 8+

### Software
- Linux (Ubuntu 22.04 LTS recommended)
- NVIDIA driver (535+ recommended)
- CUDA 12.0+
- Docker (optional but recommended)
- Python 3.8-3.11

## 🔧 Step 1: System Preparation

### 1.1 Update the system
```bash
sudo apt update && sudo apt upgrade -y
sudo apt install -y curl wget git build-essential python3-pip python3-venv
```

### 1.2 Install the NVIDIA driver and CUDA
```bash
# check the GPU
lspci | grep -i nvidia

# add the NVIDIA package repository
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.0-1_all.deb
sudo dpkg -i cuda-keyring_1.0-1_all.deb
sudo apt-get update

# install the NVIDIA driver and CUDA
sudo apt-get install -y nvidia-driver-535 cuda-12-2

# reboot
sudo reboot
```

### 1.3 Verify the GPU installation
```bash
# verify after the reboot
nvidia-smi
nvcc --version
```

## 🐳 Step 2: Docker Environment (recommended)

### 2.1 Install Docker
```bash
# install Docker
curl -fsSL https://get.docker.com -o get-docker.sh
sudo sh get-docker.sh
sudo usermod -aG docker $USER
```

### 2.2 Install the NVIDIA Container Toolkit
```bash
# add the NVIDIA Docker repository
distribution=$(. /etc/os-release;echo $ID$VERSION_ID)
curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add -
curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list

# install nvidia-container-toolkit
sudo apt-get update
sudo apt-get install -y nvidia-container-toolkit
sudo systemctl restart docker
```

### 2.3 Create the Dockerfile
```bash
# create the Dockerfile
cat > Dockerfile << 'EOF'
FROM nvidia/cuda:12.2-devel-ubuntu22.04

# non-interactive mode
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1

# update the system and install Python
RUN apt-get update && apt-get install -y \
    python3 \
    python3-pip \
    python3-venv \
    git \
    curl \
    && rm -rf /var/lib/apt/lists/*

# create the working directory
WORKDIR /app

# copy project files
COPY requirements.txt .
COPY *.py .
COPY *.md .

# install Python dependencies
RUN pip3 install --no-cache-dir -r requirements.txt

# expose the port (for a web interface)
EXPOSE 8000

# start command
CMD ["python3", "main.py"]
EOF
```

## 🐍 Step 3: Python Environment (direct deployment)

### 3.1 Create a Python virtual environment
```bash
# clone the project
git clone <your-repo-url> adaptive_rag
cd adaptive_rag

# create a virtual environment
python3 -m venv rag_env
source rag_env/bin/activate

# upgrade pip
pip install --upgrade pip
```

### 3.2 Adapt requirements.txt for the GPU
Create a GPU-optimized requirements file. Note that the PyTorch index URL must go on its own line; inline `--index-url` options are not valid requirements syntax:

```bash
# create the GPU-optimized requirements file
cat > requirements_gpu.txt << 'EOF'
# core frameworks
langchain>=0.1.0
langgraph>=0.0.40
langchain-community>=0.0.20
langchain-core>=0.1.0

# LLM integration
langchain-ollama>=0.1.0

# vector store and embeddings (GPU-optimized builds)
--extra-index-url https://download.pytorch.org/whl/cu118
chromadb>=0.4.0
sentence-transformers>=2.2.0
torch>=2.0.0
torchvision>=0.15.0
transformers>=4.30.0
accelerate>=0.20.0

# document processing
tiktoken>=0.5.0
beautifulsoup4>=4.12.0
requests>=2.31.0

# web search
tavily-python>=0.3.0

# data handling
numpy>=1.24.0,<2.0
pandas>=2.0.0

# utilities
python-dotenv>=1.0.0
pydantic>=2.0.0
typing-extensions>=4.0.0

# GPU acceleration libraries
cupy-cuda12x>=12.0.0
faiss-gpu>=1.7.4
EOF
```

### 3.3 Install the dependencies
```bash
# install the GPU-optimized dependencies
pip install -r requirements_gpu.txt
```

## 🛠️ Step 4: Configuration Changes for the GPU

### 4.1 Point document_processor.py at the GPU
Change the embedding-model configuration:

```python
# in document_processor.py
self.embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={'device': 'cuda'},  # use the GPU
    encode_kwargs={'normalize_embeddings': True}
)
```

### 4.2 Create a GPU configuration module
```bash
# create gpu_config.py
cat > gpu_config.py << 'EOF'
import torch
import os

# GPU configuration
if torch.cuda.is_available():
    DEVICE = "cuda"
    GPU_COUNT = torch.cuda.device_count()
    GPU_NAME = torch.cuda.get_device_name(0)
    print(f"Found {GPU_COUNT} GPU(s): {GPU_NAME}")

    # CUDA tuning
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

    # GPU memory management
    torch.cuda.empty_cache()
else:
    DEVICE = "cpu"
    print("No GPU found, falling back to CPU")

# tuned settings
EMBEDDING_BATCH_SIZE = 32 if DEVICE == "cuda" else 8
MAX_WORKERS = 4 if DEVICE == "cuda" else 2
EOF
```

## 🤖 Step 5: Install and Configure Ollama

### 5.1 Install Ollama
```bash
# download and install Ollama
curl -fsSL https://ollama.ai/install.sh | sh

# or use Docker
# docker run -d --gpus=all -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama
```

### 5.2 Download models
```bash
# download the Mistral model
ollama pull mistral

# or pull larger models (if GPU memory allows)
ollama pull llama2:13b
ollama pull codellama:34b
```

### 5.3 Start the Ollama service
```bash
# start the service
ollama serve &

# verify
curl http://localhost:11434/api/version
```

## 🔐 Step 6: Environment Variables

### 6.1 Create the .env file
```bash
cat > .env << 'EOF'
# API keys
TAVILY_API_KEY=your_tavily_api_key_here

# GPU configuration
CUDA_VISIBLE_DEVICES=0
TORCH_CUDA_ARCH_LIST="8.9"  # RTX 4090 architecture

# model cache
HF_HOME=/app/models
TRANSFORMERS_CACHE=/app/models

# performance tuning
OMP_NUM_THREADS=8
MKL_NUM_THREADS=8
EOF
```

## 🚀 Step 7: Deploy and Start

### 7.1 Docker deployment
```bash
# build the image
docker build -t adaptive-rag:gpu .

# run the container
docker run -d \
  --gpus all \
  --name adaptive-rag \
  --env-file .env \
  -p 8000:8000 \
  -v $(pwd)/data:/app/data \
  adaptive-rag:gpu
```

### 7.2 Direct Python deployment
```bash
# activate the virtual environment
source rag_env/bin/activate

# start the system
python main.py
```

## 📊 Step 8: Performance Monitoring

### 8.1 Create a monitoring script
```bash
cat > monitor_gpu.py << 'EOF'
import psutil
import GPUtil
import time

def monitor_system():
    while True:
        # GPU monitoring
        gpus = GPUtil.getGPUs()
        for gpu in gpus:
            print(f"GPU {gpu.id}: {gpu.load*100}% | memory: {gpu.memoryUsed}MB/{gpu.memoryTotal}MB")

        # CPU and RAM monitoring
        print(f"CPU: {psutil.cpu_percent()}% | RAM: {psutil.virtual_memory().percent}%")
        print("-" * 50)
        time.sleep(5)

if __name__ == "__main__":
    monitor_system()
EOF

pip install gputil
python monitor_gpu.py
```

## 🔧 Step 9: Performance Tuning

### 9.1 Create an optimized launch script
```bash
cat > start_optimized.sh << 'EOF'
#!/bin/bash

# GPU tuning environment variables
export CUDA_VISIBLE_DEVICES=0
export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512
export TOKENIZERS_PARALLELISM=false

# start the system
source rag_env/bin/activate
python main.py
EOF

chmod +x start_optimized.sh
```

### 9.2 Create a system service
```bash
# create a systemd unit
sudo tee /etc/systemd/system/adaptive-rag.service > /dev/null << 'EOF'
[Unit]
Description=Adaptive RAG System
After=network.target

[Service]
Type=simple
User=your_username
WorkingDirectory=/path/to/adaptive_rag
Environment=PATH=/path/to/adaptive_rag/rag_env/bin
ExecStart=/path/to/adaptive_rag/rag_env/bin/python main.py
Restart=always
RestartSec=10

[Install]
WantedBy=multi-user.target
EOF

# enable the service
sudo systemctl daemon-reload
sudo systemctl enable adaptive-rag
sudo systemctl start adaptive-rag
```

## 🐛 Step 10: Troubleshooting

### 10.1 Common problems

1. **CUDA out of memory**
```bash
# reduce the batch size
export EMBEDDING_BATCH_SIZE=16
# or cap the allocator's split size
export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:256
```

2. **Ollama connection problems**
```bash
# check Ollama's status
sudo systemctl status ollama
# restart Ollama
sudo systemctl restart ollama
```

3. **Permission problems**
```bash
# add the user to the docker group
sudo usermod -aG docker $USER
# then log out and back in
```

### 10.2 Performance tuning

```bash
# GPU performance mode
sudo nvidia-smi -pm 1
sudo nvidia-smi -ac 9251,2100

# system tuning
echo 'vm.swappiness=10' | sudo tee -a /etc/sysctl.conf
sudo sysctl -p
```

## 📈 Expected Performance

Expected performance on an RTX 4090 (see the measurement sketch below):
- **Document embedding**: ~1000 documents/second
- **Query latency**: ~2-5 seconds per query
- **GPU utilization**: 60-80%
- **Memory usage**: 8-12 GB GPU memory
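
A hedged sketch for checking the embedding-throughput figure yourself; the model and batch size mirror the configuration above, and the synthetic documents are an assumption (real throughput depends on document length):

```python
import time
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device="cuda")
docs = ["a short test document " * 20] * 2000  # synthetic workload

t0 = time.time()
model.encode(docs, batch_size=32, show_progress_bar=False)
elapsed = time.time() - t0
print(f"{len(docs) / elapsed:.0f} documents/second")  # compare against ~1000/s
```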

## 🎯 Verifying the Deployment

```bash
# test GPU availability
python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}'); print(f'GPU count: {torch.cuda.device_count()}')"

# test the system
curl -X POST http://localhost:8000/query \
  -H "Content-Type: application/json" \
  -d '{"question": "What is an LLM agent?"}'
```

This deployment guide covers the complete Linux GPU environment setup so that the adaptive RAG system can take full advantage of the RTX 4090's compute capability.
Dockerfile.gpu
ADDED
@@ -0,0 +1,58 @@
# GPU-optimized Dockerfile - targeting the RTX 4090
FROM nvidia/cuda:12.2-devel-ubuntu22.04

# non-interactive mode and environment variables
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
ENV CUDA_VISIBLE_DEVICES=0
ENV PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512

# update the system and install required packages
RUN apt-get update && apt-get install -y \
    python3 \
    python3-pip \
    python3-venv \
    git \
    curl \
    wget \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# application directory
WORKDIR /app

# create required directories
RUN mkdir -p /app/data /app/models /app/logs

# copy the dependency file
COPY requirements_gpu.txt .

# upgrade pip and install Python dependencies
RUN pip3 install --no-cache-dir --upgrade pip && \
    pip3 install --no-cache-dir -r requirements_gpu.txt

# copy application files
COPY *.py .
COPY *.md .
COPY .env.example .

# Python path
ENV PYTHONPATH=/app

# create the startup script
RUN echo '#!/bin/bash\n\
export CUDA_VISIBLE_DEVICES=0\n\
export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512\n\
export TOKENIZERS_PARALLELISM=false\n\
python3 -c "import torch; print(f'"'"'CUDA available: {torch.cuda.is_available()}'"'"'); print(f'"'"'GPU count: {torch.cuda.device_count()}'"'"')"\n\
python3 main.py' > /app/start.sh && chmod +x /app/start.sh

# expose ports
EXPOSE 8000 8001

# health check
HEALTHCHECK --interval=30s --timeout=30s --start-period=60s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# start command
CMD ["/app/start.sh"]
GRAPHRAG_GUIDE.md
ADDED
@@ -0,0 +1,401 @@
# GraphRAG Integration Guide

## 📋 Overview

This project integrates a **Microsoft GraphRAG**-style architecture: a knowledge graph augments traditional vector retrieval, giving more precise information extraction and reasoning.

## 🏗️ GraphRAG Architecture

### Core components

```
Document collection
        ↓
┌─────────────────────────────────────┐
│ Entity & relation extraction        │
│ - LLM identifies entities           │
│ - extracts relations between them   │
└─────────────────────────────────────┘
        ↓
┌─────────────────────────────────────┐
│ Knowledge-graph construction        │
│ - entity deduplication              │
│ - graph structure building          │
└─────────────────────────────────────┘
        ↓
┌─────────────────────────────────────┐
│ Community detection                 │
│ - Louvain algorithm                 │
│ - hierarchical clustering           │
└─────────────────────────────────────┘
        ↓
┌─────────────────────────────────────┐
│ Community summaries                 │
│ - LLM-generated summaries           │
│ - multi-level index                 │
└─────────────────────────────────────┘
        ↓
   Query stage
        ↓
┌──────────────┬──────────────┐
│ Local query  │ Global query │
│              │              │
│ entity-      │ community-   │
│ neighborhood │ summary      │
│ retrieval    │ lookup       │
└──────────────┴──────────────┘
```

## 📦 New Files

### 1. **entity_extractor.py** - entity extractor
```python
EntityExtractor
├── extract_entities()        # extract entities from text
├── extract_relations()       # extract entity relations
└── extract_from_document()   # full document processing

EntityDeduplicator
└── deduplicate_entities()    # entity deduplication
```

**Features** (a schema sketch follows):
- Uses an LLM to identify 6 entity types (PERSON, ORGANIZATION, CONCEPT, TECHNOLOGY, PAPER, EVENT)
- Extracts 8 relation types (AUTHOR_OF, USES, BASED_ON, etc.)
- Smart entity deduplication and merging
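
A minimal sketch of the data shapes such an extractor might produce; the field names here are illustrative, not the exact attributes defined in `entity_extractor.py`:

```python
from dataclasses import dataclass

ENTITY_TYPES = {"PERSON", "ORGANIZATION", "CONCEPT", "TECHNOLOGY", "PAPER", "EVENT"}

@dataclass
class Entity:
    name: str         # canonical name after deduplication
    type: str         # one of ENTITY_TYPES
    description: str  # LLM-written one-liner

@dataclass
class Relation:
    source: str       # entity name
    target: str       # entity name
    type: str         # e.g. "AUTHOR_OF", "USES", "BASED_ON"

e = Entity("AlphaCodium", "TECHNOLOGY", "A code-generation flow")
r = Relation("AlphaCodium", "GPT-4", "USES")
```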

### 2. **knowledge_graph.py** - knowledge-graph core
```python
KnowledgeGraph
├── add_entity()                # add a node
├── add_relation()              # add an edge
├── build_from_extractions()    # build the graph
├── detect_communities()        # community detection
├── get_community_members()     # fetch community members
└── get_statistics()            # statistics

CommunitySummarizer
├── summarize_community()       # summarize one community
└── summarize_all_communities() # summarize all communities
```

**Features** (a community-detection sketch follows):
- NetworkX-based graph management
- Supports 3 community-detection algorithms (Louvain, Greedy, Label Propagation)
- LLM-driven community summaries
- Graph persistence to disk
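
A minimal sketch of Louvain community detection on a NetworkX graph, using the `python-louvain` package that the Colab dependency list pulls in; the toy edges are illustrative:

```python
import networkx as nx
import community as community_louvain  # package: python-louvain

g = nx.Graph()
g.add_edges_from([
    ("AlphaCodium", "GPT-4"), ("GPT-4", "OpenAI"),       # toy edges
    ("LLM Agent", "Planning"), ("LLM Agent", "Memory"),
])

# partition: {node -> community id}
partition = community_louvain.best_partition(g)

# group members by community, as get_community_members() would
communities = {}
for node, cid in partition.items():
    communities.setdefault(cid, []).append(node)
print(communities)
```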

### 3. **graph_indexer.py** - index builder
```python
GraphRAGIndexer
├── index_documents()   # build the index
├── get_graph()         # fetch the graph
└── load_index()        # load an index
```

**Pipeline**:
1. Batch entity extraction
2. Entity deduplication and merging
3. Knowledge-graph construction
4. Community detection
5. Summary generation

### 4. **graph_retriever.py** - graph retriever
```python
GraphRetriever
├── recognize_entities()  # recognize entities in the question
├── local_query()         # local query
├── global_query()        # global query
├── hybrid_query()        # hybrid query
└── smart_query()         # smart query
```

**Query modes**:
- **Local query**: detailed questions about specific entities
- **Global query**: broad questions that need whole-corpus understanding
- **Smart query**: automatically picks the best strategy

### 5. **main_graphrag.py** - GraphRAG integration example
A complete usage example with an interactive interface

### 6. **requirements_graphrag.txt** - extra dependencies
The graph-processing libraries GraphRAG needs

## 🚀 Quick Start

### Install dependencies

```bash
# base dependencies
pip install -r requirements.txt

# GraphRAG dependencies
pip install -r requirements_graphrag.txt
```

### First use

```python
# Option 1: use the integration example
python main_graphrag.py

# Option 2: integrate in code
from config import setup_environment
from document_processor import initialize_document_processor
from graph_indexer import initialize_graph_indexer
from graph_retriever import initialize_graph_retriever

# initialize
setup_environment()
processor, vectorstore, retriever, doc_splits = initialize_document_processor()

# build the GraphRAG index
graph_indexer = initialize_graph_indexer()
knowledge_graph = graph_indexer.index_documents(
    documents=doc_splits,
    save_path="./data/knowledge_graph.json"
)

# initialize the retriever
graph_retriever = initialize_graph_retriever(knowledge_graph)

# query
answer = graph_retriever.smart_query("What are the core components of an LLM agent?")
print(answer)
```

## 🔧 Configuration

The following settings were added to `config.py`:

```python
# GraphRAG configuration
ENABLE_GRAPHRAG = True                               # enable GraphRAG
GRAPHRAG_INDEX_PATH = "./data/knowledge_graph.json"  # graph storage path
GRAPHRAG_COMMUNITY_ALGORITHM = "louvain"             # community-detection algorithm
GRAPHRAG_MAX_HOPS = 2                                # max hops for local queries
GRAPHRAG_TOP_K_COMMUNITIES = 5                       # communities used by global queries
GRAPHRAG_BATCH_SIZE = 10                             # entity-extraction batch size
```

## 📊 When to Use What

### Traditional vector retrieval vs GraphRAG

| Scenario | Vector retrieval | GraphRAG | Recommendation |
|-----|---------|----------|------|
| "Who are the authors of AlphaCodium?" | ⚠️ May find it, imprecisely | ✅ Direct entity-relation lookup | GraphRAG local query |
| "What topics do these documents cover?" | ⚠️ Must read many chunks | ✅ Community summaries answer directly | GraphRAG global query |
| "Applications of prompt engineering" | ✅ Good semantic match | ✅ Can trace relation chains | Hybrid query |
| "Latest technology developments" | ✅ Suits fuzzy queries | ❌ Needs explicit entities | Vector retrieval |

## 🎯 Choosing a Query Strategy

### Local query

**Use for**: detailed questions about specific entities

```python
# example questions
"What components does an LLM agent contain?"
"Who are the authors of the Transformer model?"
"What techniques does AlphaCodium use?"

# code
answer = graph_retriever.local_query(question, max_hops=2)
```

**How it works** (see the sketch after this list):
1. Recognize the entities in the question
2. Expand to neighbor nodes (multi-hop supported)
3. Collect entity information and relations
4. Generate an answer from the subgraph
|
| 211 |
+
|
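
Steps 1-2 amount to an ego-graph expansion around the recognized entities. A runnable sketch with plain NetworkX and a toy graph standing in for this project's classes:

```python
import networkx as nx

def expand_neighborhood(graph: nx.Graph, seeds, max_hops: int = 2) -> nx.Graph:
    """Return the subgraph of every node within max_hops of any seed entity."""
    nodes = set()
    for seed in seeds:
        if seed in graph:
            # ego_graph collects all nodes within `radius` hops of the seed
            nodes |= set(nx.ego_graph(graph, seed, radius=max_hops).nodes)
    return graph.subgraph(nodes)

# toy graph standing in for KnowledgeGraph.graph
g = nx.Graph()
g.add_edge("LLM Agent", "Planning", relation_type="HAS_COMPONENT")
g.add_edge("LLM Agent", "Memory", relation_type="HAS_COMPONENT")
g.add_edge("Memory", "Vector Store", relation_type="USES")

sub = expand_neighborhood(g, seeds=["LLM Agent"], max_hops=2)  # steps 1-2
# steps 3-4: serialize the subgraph into context for the LLM prompt
context = "\n".join(f"{u} -[{d['relation_type']}]-> {v}"
                    for u, v, d in sub.edges(data=True))
print(context)
```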

### Global Query
**Use for**: broad questions that need a holistic view

```python
# example questions
"What are the main topics of these documents?"
"Which research areas are covered?"
"What are the key technology trends?"

# code
answer = graph_retriever.global_query(question, top_k_communities=5)
```

**How it works** (a map-reduce over community summaries, sketched below):
1. Fetch the community summaries
2. Use them to understand the global structure
3. Synthesize information across communities
4. Generate a high-level answer
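
Below is a sketch of that map-reduce, assuming `llm` is any `str -> str` callable (for example a wrapped ChatOllama invocation). Both `llm` and the naive keyword-overlap selection are illustrative stand-ins, not the project's exact implementation:

```python
def global_query_sketch(question: str, summaries: dict, llm, top_k: int = 5) -> str:
    # steps 1-2: pick the top-k community summaries (naive keyword overlap here)
    def overlap(summary: str) -> int:
        return len(set(question.lower().split()) & set(summary.lower().split()))
    top = sorted(summaries.values(), key=overlap, reverse=True)[:top_k]

    # step 3 (map): ask what each community contributes to the question
    partials = [
        llm(f"Question: {question}\nCommunity summary: {s}\nRelevant points:")
        for s in top
    ]

    # step 4 (reduce): synthesize one high-level answer
    joined = "\n\n".join(partials)
    return llm(f"Question: {question}\nPartial answers:\n{joined}\nFinal answer:")
```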

### Smart Query
**Use for**: automatic strategy selection

```python
# automatically decides between a local and a global query
answer = graph_retriever.smart_query(question)
```

**Decision logic** (sketched below):
- Contains a concrete entity name → local query
- Contains overview wording such as "main", "overall", "overview" → global query
- Otherwise → local query
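
The heuristic fits in a few lines. In this sketch the marker list and the substring matching are illustrative stand-ins for the real logic in `graph_retriever.py`:

```python
GLOBAL_MARKERS = ("main", "overall", "overview", "summary", "themes")

def smart_route(question: str, known_entities: set) -> str:
    q = question.lower()
    if any(entity.lower() in q for entity in known_entities):
        return "local"   # a concrete graph entity is mentioned
    if any(marker in q for marker in GLOBAL_MARKERS):
        return "global"  # overview wording
    return "local"       # default

print(smart_route("What are the main themes here?", {"AlphaCodium"}))  # -> global
print(smart_route("What does AlphaCodium use?", {"AlphaCodium"}))      # -> local
```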

### Hybrid Query
**Use for**: complex questions that need multiple perspectives

```python
result = graph_retriever.hybrid_query(question)
# returns: {"local": "...", "global": "..."}
```

## 📈 Performance Tuning

### Index Construction

```python
import os
from config import GRAPHRAG_INDEX_PATH

# 1. batch size
graph_indexer.index_documents(
    documents=doc_splits,
    batch_size=20  # larger batches speed up indexing
)

# 2. incremental indexing (in development)
# avoids rebuilding the whole graph each time

# 3. reuse a cached index
if os.path.exists(GRAPHRAG_INDEX_PATH):
    knowledge_graph = graph_indexer.load_index(GRAPHRAG_INDEX_PATH)
```

### Queries

```python
# 1. tune the hop count
answer = graph_retriever.local_query(question, max_hops=1)  # fewer hops, faster

# 2. limit the number of communities
answer = graph_retriever.global_query(question, top_k_communities=3)

# 3. entity-recognition caching (in development)
```

## 🔍 Debugging and Visualization

### Graph Statistics

```python
stats = knowledge_graph.get_statistics()
print(f"Nodes: {stats['num_nodes']}")
print(f"Edges: {stats['num_edges']}")
print(f"Communities: {stats['num_communities']}")
```

### Exporting the Graph

```python
# save as JSON
knowledge_graph.save_to_file("my_graph.json")

# load a graph
knowledge_graph.load_from_file("my_graph.json")
```

### Visualization (optional)

```python
# requires an extra install: pip install pyvis
from pyvis.network import Network

def visualize_graph(kg, output="graph.html"):
    net = Network(height="750px", width="100%", bgcolor="#222222", font_color="white")

    for node, data in kg.graph.nodes(data=True):
        net.add_node(node, label=node, title=data.get('description', ''))

    for u, v, data in kg.graph.edges(data=True):
        net.add_edge(u, v, title=data.get('relation_type', ''))

    net.show(output)
    print(f"Graph saved to: {output}")
```

## ⚠️ FAQ

### Q1: Entity extraction quality is poor?
**A**:
- Tune the LLM temperature
- Refine the entity-extraction prompt
- Use a stronger LLM

### Q2: Index construction takes too long?
**A**:
- Increase the batch size
- Test on fewer documents first
- Reuse a cached index file

### Q3: Query results are irrelevant?
**A**:
- Check whether entity recognition is accurate
- Switch the query strategy (local/global)
- Increase the neighbor hop count

### Q4: Memory usage is too high?
**A**:
- Use a lighter-weight graph store
- Process large document sets in batches
- Cap the community-detection iterations

## 🔄 Integrating with the Existing System

### Modifying the Existing main.py

```python
from config import ENABLE_GRAPHRAG
from graph_indexer import initialize_graph_indexer
from graph_retriever import initialize_graph_retriever

class AdaptiveRAGSystem:
    def __init__(self):
        # ... existing initialization ...

        # add GraphRAG support
        if ENABLE_GRAPHRAG:
            self._setup_graphrag()

    def _setup_graphrag(self):
        self.graph_indexer = initialize_graph_indexer()
        # ... index construction ...
        self.graph_retriever = initialize_graph_retriever(self.knowledge_graph)

    def query(self, question: str):
        # combine vector retrieval with graph queries
        vector_docs = self.retriever.get_relevant_documents(question)

        if ENABLE_GRAPHRAG:
            graph_answer = self.graph_retriever.smart_query(question)
            # merge the two kinds of results
            return self._merge_results(vector_docs, graph_answer)

        return self._generate_from_docs(vector_docs)
```

## 📚 References

- [Microsoft GraphRAG paper](https://arxiv.org/abs/2404.16130)
- [NetworkX documentation](https://networkx.org/)
- [Louvain community detection](https://en.wikipedia.org/wiki/Louvain_method)

## 🛣️ Future Enhancements

- [ ] Incremental index updates
- [ ] Multi-modal knowledge graphs
- [ ] Graph visualization UI
- [ ] Neo4j integration (production)
- [ ] Knowledge graph reasoning engine
- [ ] Entity linking improvements
- [ ] Automatic entity disambiguation

---

**Tip**: on first use, test on a small dataset and verify the results before indexing the full corpus.
GRAPHRAG_INTEGRATION_SUMMARY.md
ADDED
@@ -0,0 +1,427 @@
# GraphRAG Integration Summary

## ✅ Completed Work

### 🆕 New Files (7)

| File | Lines | Main function |
|------|------|---------|
| **entity_extractor.py** | 225 | Entity/relation extraction, deduplication |
| **knowledge_graph.py** | 348 | Graph construction, community detection, summaries |
| **graph_indexer.py** | 146 | GraphRAG index-building pipeline |
| **graph_retriever.py** | 276 | Local/global/smart queries |
| **main_graphrag.py** | 294 | Full usage example and interactive UI |
| **requirements_graphrag.txt** | 32 | Extra GraphRAG dependencies |
| **GRAPHRAG_GUIDE.md** | 402 | Detailed usage guide |

### 🔧 Modified Files (3)

| File | Change |
|------|---------|
| **config.py** | Added 7 GraphRAG configuration parameters |
| **document_processor.py** | `setup_knowledge_base()` now also returns doc_splits |
| **requirements.txt** | Added networkx and python-louvain |

---

## 📋 Modification Details

### 1. config.py - New Settings

```python
# GraphRAG configuration
ENABLE_GRAPHRAG = True
GRAPHRAG_INDEX_PATH = "./data/knowledge_graph.json"
GRAPHRAG_COMMUNITY_ALGORITHM = "louvain"
GRAPHRAG_MAX_HOPS = 2
GRAPHRAG_TOP_K_COMMUNITIES = 5
GRAPHRAG_BATCH_SIZE = 10
```

### 2. document_processor.py - Function Changes

```python
# before
def setup_knowledge_base(self, urls=None):
    ...
    return vectorstore, retriever

# after
def setup_knowledge_base(self, urls=None, enable_graphrag=False):
    ...
    return vectorstore, retriever, doc_splits  # now also returns doc_splits

# updated accordingly
def initialize_document_processor():
    ...
    return processor, vectorstore, retriever, doc_splits  # now includes doc_splits
```

### 3. requirements.txt - New Dependencies

```txt
# GraphRAG (optional)
networkx>=3.1
python-louvain>=0.16
```

---

## 🏗️ GraphRAG Architecture Overview

```
┌─────────────────────────────────────────────────────────────┐
│                  Document processing layer                   │
│  document_processor.py → doc_splits                          │
└────────────────────────┬─────────────────────────────────────┘
                         ↓
┌─────────────────────────────────────────────────────────────┐
│                  Entity extraction layer                     │
│  entity_extractor.py                                         │
│  ├── EntityExtractor (entity & relation extraction)          │
│  └── EntityDeduplicator (deduplication)                       │
└────────────────────────┬─────────────────────────────────────┘
                         ↓
┌─────────────────────────────────────────────────────────────┐
│                  Graph construction layer                    │
│  knowledge_graph.py                                           │
│  ├── KnowledgeGraph (graph management)                        │
│  │   ├── NetworkX graph structure                             │
│  │   ├── Community detection (Louvain/Greedy/LabelProp)       │
│  │   └── Statistics                                           │
│  └── CommunitySummarizer (community summaries)                │
└────────────────────────┬─────────────────────────────────────┘
                         ↓
┌─────────────────────────────────────────────────────────────┐
│                  Index building layer                        │
│  graph_indexer.py                                             │
│  └── GraphRAGIndexer                                          │
│      ├── 5-step indexing pipeline                             │
│      └── Graph persistence                                    │
└────────────────────────┬─────────────────────────────────────┘
                         ↓
┌─────────────────────────────────────────────────────────────┐
│                  Retrieval/query layer                       │
│  graph_retriever.py                                           │
│  └── GraphRetriever                                           │
│      ├── Local Query                                          │
│      ├── Global Query                                         │
│      ├── Hybrid Query                                         │
│      └── Smart Query                                          │
└────────────────────────┬─────────────────────────────────────┘
                         ↓
┌─────────────────────────────────────────────────────────────┐
│                  Application layer                           │
│  main_graphrag.py                                             │
│  └── AdaptiveRAGWithGraph                                     │
│      ├── 5 query modes                                        │
│      ├── Statistics display                                   │
│      └── Interactive interface                                │
└─────────────────────────────────────────────────────────────┘
```

---

## 🚀 Usage

### Option 1: Run the Example Directly

```bash
# 1. install dependencies
pip install -r requirements.txt
pip install -r requirements_graphrag.txt

# 2. run the GraphRAG example
python main_graphrag.py

# the first run builds the index; later runs load the cached one
```

### Option 2: Integrate into Existing Code

```python
# integration inside main.py
import os
from config import ENABLE_GRAPHRAG, GRAPHRAG_INDEX_PATH
from graph_indexer import initialize_graph_indexer
from graph_retriever import initialize_graph_retriever

class AdaptiveRAGSystem:
    def __init__(self):
        # ... existing initialization ...

        if ENABLE_GRAPHRAG:
            # build or load the graph
            self.graph_indexer = initialize_graph_indexer()

            if os.path.exists(GRAPHRAG_INDEX_PATH):
                self.kg = self.graph_indexer.load_index(GRAPHRAG_INDEX_PATH)
            else:
                self.kg = self.graph_indexer.index_documents(
                    self.doc_splits,
                    save_path=GRAPHRAG_INDEX_PATH
                )

            # initialize the retriever
            self.graph_retriever = initialize_graph_retriever(self.kg)

    def query(self, question: str):
        if ENABLE_GRAPHRAG:
            # smart graph query
            return self.graph_retriever.smart_query(question)
        else:
            # original logic
            ...
```

---

## 📊 Feature Comparison

### Original System vs. GraphRAG-Enhanced

| Feature | Original | GraphRAG-enhanced | Gain |
|------|--------|--------------|------|
| **Retrieval** | Vector similarity | Vector + graph | ✅ Multi-mode retrieval |
| **Relation understanding** | ❌ None | ✅ Explicit relations | ✅ Relational reasoning |
| **Multi-hop reasoning** | ❌ Limited | ✅ N-hop support | ✅ Complex reasoning |
| **Global understanding** | ⚠️ Read many documents | ✅ Community summaries | ✅ Efficient overviews |
| **Entity disambiguation** | ❌ None | ✅ Graph context | ✅ Accurate recognition |
| **Fact verification** | Document matching | Relation-based checks | ✅ Stricter |

---

## 🎯 Suitable Scenarios

### GraphRAG shines for:

✅ **Knowledge-dense domains**
- Academic papers, technical docs
- Understanding entity relationships
- e.g. "What other techniques did the AlphaCodium authors study?"

✅ **Questions that need reasoning**
- Multi-hop relation queries
- Causal analysis
- e.g. "How does prompt engineering apply to defending against adversarial attacks?"

✅ **Overview questions**
- Topic summarization
- Research trends
- e.g. "What are the main research directions in this field?"

### Still use vector retrieval for:

⚠️ **Fuzzy semantic queries**
- No clear entity
- Needs semantic similarity matching

⚠️ **Fresh-news queries**
- Content the graph does not cover yet
- Needs web search

---

## 🔧 Configuration Parameters

```python
# config.py

ENABLE_GRAPHRAG = True
# whether to enable GraphRAG; False falls back to pure vector retrieval

GRAPHRAG_INDEX_PATH = "./data/knowledge_graph.json"
# graph persistence path, avoids rebuilding every run

GRAPHRAG_COMMUNITY_ALGORITHM = "louvain"
# community detection algorithm:
# - "louvain": best quality (recommended)
# - "greedy": faster
# - "label_propagation": fast approximation

GRAPHRAG_MAX_HOPS = 2
# neighbor depth for local queries
# 1: direct neighbors only
# 2: two-hop neighbors (recommended)
# 3+: may pull in too much noise

GRAPHRAG_TOP_K_COMMUNITIES = 5
# communities used in global queries
# more communities = more complete but slower

GRAPHRAG_BATCH_SIZE = 10
# entity extraction batch size
# larger batches = faster but more memory
```

---

## 📈 Performance Characteristics

### Index Build Time

| Document chunks | Entities | Relations | Communities | Build time* |
|---------|--------|--------|--------|----------|
| 10 | ~50 | ~30 | 3-5 | ~2 min |
| 50 | ~200 | ~150 | 8-12 | ~8 min |
| 100 | ~400 | ~300 | 15-20 | ~15 min |

*With the Mistral model; actual times depend on LLM speed.

### Query Latency

| Query type | Average | Notes |
|---------|---------|------|
| Local query | 2-5 s | LLM generates the answer |
| Global query | 3-8 s | Processes several community summaries |
| Smart query | 2-8 s | Depends on the chosen strategy |
| Hybrid query | 5-12 s | Runs both queries |

### Storage

- **Graph index**: 100 document chunks ≈ 1-5 MB (JSON)
- **Memory footprint**: ≈ 200-500 MB at runtime (depends on graph size)

---

## 🐛 Troubleshooting

### Problem 1: Entity extraction fails
```
❌ Entity extraction failed: timeout
```

**Fixes**:
- Check that the Ollama service is running: `ollama serve` (see the health-check sketch below)
- Reduce the batch size: `GRAPHRAG_BATCH_SIZE = 5`
- Use a faster LLM
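
Before a long extraction run it is worth confirming Ollama is reachable. A small check, assuming Ollama's default local port 11434 (`/api/tags` is its model-listing endpoint):

```python
import requests

try:
    resp = requests.get("http://localhost:11434/api/tags", timeout=5)
    resp.raise_for_status()
    models = [m["name"] for m in resp.json().get("models", [])]
    print("Ollama is up; pulled models:", models)
except requests.exceptions.RequestException:
    print("Ollama is unreachable - start it with `ollama serve`")
```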

### Problem 2: Community detection fails
```
⚠️ python-louvain is not installed
```

**Fixes**:
```bash
pip install python-louvain
# or switch algorithms
GRAPHRAG_COMMUNITY_ALGORITHM = "greedy"
```

### Problem 3: Queries return nothing
```
No matching entities found in the knowledge graph
```

**Fixes**:
- Check that the graph was built: `rag_system.get_graph_statistics()`
- Use a global query instead of a local one
- Inspect entity-extraction quality

### Problem 4: Out of memory
```
MemoryError
```

**Fixes**:
- Test on fewer documents
- Add gaps between batches
- Use a lightweight graph store

---

## 📝 Code Examples

### Example 1: Basic Usage

```python
from main_graphrag import AdaptiveRAGWithGraph

# initialize the system
rag = AdaptiveRAGWithGraph(enable_graphrag=True)

# local query (specific entity)
answer = rag.query_graph_local("What are the main components of an LLM agent?")

# global query (overview question)
answer = rag.query_graph_global("What topics do these documents cover?")

# smart query (automatic strategy)
answer = rag.query_smart("How do you defend against adversarial attacks?")
```

### Example 2: Hybrid Retrieval

```python
# use vectors and the graph together
result = rag.query_hybrid("Applications of prompt engineering in LLMs")

print("Vector retrieval:", result["vector_retrieval"]["context"])
print("Graph local:", result["graph_local"])
print("Graph global:", result["graph_global"])
```

### Example 3: Manual Control

```python
from graph_indexer import initialize_graph_indexer
from graph_retriever import initialize_graph_retriever

# build the index
indexer = initialize_graph_indexer()
kg = indexer.index_documents(documents, save_path="my_graph.json")

# inspect statistics
stats = kg.get_statistics()
print(f"Entities: {stats['num_nodes']}, relations: {stats['num_edges']}")

# query
retriever = initialize_graph_retriever(kg)
answer = retriever.local_query("specific question", max_hops=3)
```

---

## 🎓 Learning Resources

### Suggested Reading Order

1. **GRAPHRAG_GUIDE.md** - detailed usage guide
2. **entity_extractor.py** - entity extraction
3. **knowledge_graph.py** - graph construction
4. **graph_retriever.py** - query strategies
5. **main_graphrag.py** - complete working example

### Key Concepts

The sketch after this list shows how they map onto NetworkX structures.

- **Entity**: a node in the graph, e.g. a person, concept, or technology
- **Relation**: an edge connecting two entities
- **Community**: a tightly connected group of nodes
- **Local query**: precise query over an entity's neighborhood
- **Global query**: overview query over community summaries
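
A minimal sketch of how these concepts look in the underlying NetworkX graph (the entity names and attributes here are illustrative):

```python
import networkx as nx

g = nx.Graph()
# entities = attributed nodes
g.add_node("AlphaCodium", type="TECHNOLOGY", description="code-generation flow")
g.add_node("Prompt Engineering", type="CONCEPT", description="LLM steering techniques")
# relations = attributed edges
g.add_edge("AlphaCodium", "Prompt Engineering", relation_type="USES")
# communities = a node -> community-id assignment produced by detection
communities = {"AlphaCodium": 0, "Prompt Engineering": 0}

print(list(g.nodes(data=True)))
print(list(g.edges(data=True)))
```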

---

## 🔮 Roadmap

- [ ] **Incremental indexing**: add new documents without rebuilding the whole graph
- [ ] **Neo4j integration**: a production-grade graph database
- [ ] **Visualization UI**: web view of the knowledge graph
- [ ] **Multi-model fusion**: combine multiple LLMs for better extraction
- [ ] **Live updates**: dynamically update the graph structure
- [ ] **Knowledge reasoning**: a graph-based inference engine
- [ ] **Performance**: parallel processing and caching

---

## 📞 Support

Having trouble?

1. Check the FAQ section of **GRAPHRAG_GUIDE.md**
2. Read the error messages in the log output
3. Run `python main_graphrag.py` to test the basics
4. Use `get_graph_statistics()` to inspect graph state

---

**Summary**: GraphRAG is now integrated into the adaptive RAG system, providing a complete workflow from entity extraction to smart querying. Choosing the right query strategy noticeably improves answer quality on complex questions.
QUICKSTART.md
ADDED
@@ -0,0 +1,73 @@
# Quick Start Guide

## Install Dependencies

```bash
pip install -r requirements.txt
```

## Environment Setup

### 1. Set API Keys

Copy `.env.example` to `.env`:
```bash
cp .env.example .env
```

Edit `.env` and fill in your real API keys:
```bash
# Tavily API key - for web search
# get one from https://tavily.com/
TAVILY_API_KEY=your_actual_tavily_api_key

# Nomic API key - for the embedding service
# get one from https://atlas.nomic.ai/
NOMIC_API_KEY=your_actual_nomic_api_key
```
|
| 28 |
+
|
| 29 |
+
### 2. 确保本地模型可用
|
| 30 |
+
|
| 31 |
+
确保您已安装并运行Ollama,并下载了Mistral模型:
|
| 32 |
+
```bash
|
| 33 |
+
# 安装Ollama
|
| 34 |
+
# 访问 https://ollama.ai/ 下载安装
|
| 35 |
+
|
| 36 |
+
# 下载Mistral模型
|
| 37 |
+
ollama pull mistral
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
## 运行系统
|
| 41 |
+
|
| 42 |
+
```bash
|
| 43 |
+
python main.py
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
## 项目结构
|
| 47 |
+
|
| 48 |
+
```
|
| 49 |
+
adaptive_RAG/
|
| 50 |
+
├── config.py # 配置和环境设置
|
| 51 |
+
├── document_processor.py # 文档处理和向量化
|
| 52 |
+
├── routers_and_graders.py # 路由器和评分器
|
| 53 |
+
├── workflow_nodes.py # 工作流节点
|
| 54 |
+
├── main.py # 主应用程序入口
|
| 55 |
+
├── requirements.txt # 依赖管理
|
| 56 |
+
├── README.md # 项目说明
|
| 57 |
+
└── QUICKSTART.md # 快速开始指南
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
## 功能模块说明
|
| 61 |
+
|
| 62 |
+
1. **config.py**: 包含所有配置项、API密钥管理和环境变量设置
|
| 63 |
+
2. **document_processor.py**: 负责文档加载、分块、向量化和检索器设置
|
| 64 |
+
3. **routers_and_graders.py**: 实现查询路由、文档评分、答案质量评估等功能
|
| 65 |
+
4. **workflow_nodes.py**: 定义所有工作流节点和状态管理
|
| 66 |
+
5. **main.py**: 系统集成和用户交互界面
|
| 67 |
+
|
| 68 |
+
## 使用示例
|
| 69 |
+
|
| 70 |
+
系统启动后会自动进入交互模式,你可以:
|
| 71 |
+
- 询问关于LLM、提示工程、对抗性攻击的问题(使用本地知识库)
|
| 72 |
+
- 询问其他问题(自动路由到网络搜索)
|
| 73 |
+
- 输入 'quit' 或 'exit' 退出系统
|
README.md
ADDED
@@ -0,0 +1,157 @@
# Adaptive Retrieval-Augmented Generation System

## 📋 Project Description

This project implements an **Adaptive RAG** system that intelligently routes user queries between a local vector database and web search. Through a staged workflow it adaptively picks the best information source for each query and keeps improving response quality via document grading and query transformation, delivering accurate, context-relevant answers.

## ✨ Core Features

### 🔄 Smart Query Routing
- Decides automatically whether to use the local vector store or web search
- Routes questions about LLM agents, prompt engineering, and adversarial attacks to the vector store
- Falls back to web search for general queries

### 📚 Advanced Document Processing
- Web content loading and chunking with the tiktoken encoder
- Vector embeddings via Nomic (local inference)
- Chroma vector database for efficient similarity search

### 🎯 Quality Assurance Pipeline
- **Document relevance grading**: filters retrieved documents for relevance to the query
- **Answer quality assessment**: checks whether the generated answer resolves the question
- **Hallucination detection**: ensures answers are grounded in source documents
- **Query transformation**: rewrites queries that failed to produce good results

### 🌐 Hybrid Information Retrieval
- Local vector database for domain-specific knowledge
- Web search via the Tavily API for fresh information
- Seamlessly combines both sources when needed

## 🛠️ Tech Stack

### Core Frameworks
- **LangChain**: orchestration framework for LLM applications
- **LangGraph**: state-graph implementation for complex workflows

### Language Models
- **Ollama**: local LLM inference (Mistral)
- **ChatOllama**: LangChain integration for Ollama

### Vector Database & Embeddings
- **Chroma**: vector database for document storage and retrieval
- **Nomic Embeddings**: local text embedding model (nomic-embed-text-v1.5)

### Document Processing
- **WebBaseLoader**: web content extraction
- **RecursiveCharacterTextSplitter**: smart tiktoken-based chunking

### External APIs
- **Tavily API**: web search
- **Nomic API**: embedding service

### Workflow Management
- **StateGraph**: manages the state of the RAG workflow
- **TypedDict**: type-safe state management

## 🏗️ System Architecture

The system implements a state machine with the following components:

```
User query → Router → [Vector store | Web search]
                ↓              ↓
        Retrieved docs     Web search
                ↓              ↓
        Grade documents --→ Generate answer
                ↓              ↓
      [sufficient | not]   Quality check
          ↓        ↓       ↓    ↓    ↓
   Generate answer     [useful] [not useful] [not supported]
          ↓                ↓         ↓             ↓
    Quality check        [END]  Transform query  Regenerate
          ↓                          ↓             ↑
 [useful | not useful |         Re-retrieve -------+
   not supported]
          ↓
   [Final answer]
```

### Workflow Nodes
1. **retrieve**: fetch relevant documents from the vector store
2. **web_search**: search the web for fresh information
3. **grade_documents**: assess document relevance
4. **generate**: create the answer with the RAG chain
5. **transform_query**: rephrase the query

### Decision Points
The sketch after this list shows how the nodes and decision points wire together.
- **route_question**: choose between the vector store and web search
- **decide_to_generate**: judge whether the documents suffice
- **grade_generation**: verify answer quality and groundedness
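
Below is a condensed sketch of how these nodes and decision points could wire into a LangGraph `StateGraph`. The node bodies live in `workflow_nodes.py` and are stubbed here; the edge labels are assumptions read off the diagram above, not copied from the project source:

```python
from typing import List, TypedDict

from langgraph.graph import StateGraph, END

class GraphState(TypedDict):
    question: str
    documents: List[str]
    generation: str

# Node stubs: the real implementations are in workflow_nodes.py.
def retrieve(state): return state
def web_search(state): return state
def grade_documents(state): return state
def generate(state): return state
def transform_query(state): return state

# Decision points return the name of the edge to follow.
def route_question(state): return "vectorstore"      # or "web_search"
def decide_to_generate(state): return "generate"     # or "transform_query"
def grade_generation(state): return "useful"         # or "not useful" / "not supported"

workflow = StateGraph(GraphState)
workflow.add_node("retrieve", retrieve)
workflow.add_node("web_search", web_search)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)

workflow.set_conditional_entry_point(
    route_question, {"vectorstore": "retrieve", "web_search": "web_search"})
workflow.add_edge("retrieve", "grade_documents")
workflow.add_edge("web_search", "generate")
workflow.add_conditional_edges(
    "grade_documents", decide_to_generate,
    {"generate": "generate", "transform_query": "transform_query"})
workflow.add_edge("transform_query", "retrieve")
workflow.add_conditional_edges(
    "generate", grade_generation,
    {"useful": END, "not useful": "transform_query", "not supported": "generate"})

app = workflow.compile()
```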

## 📖 Functional Modules

### 1. **Environment Setup**
- Secure API key management
- Local LLM configuration

### 2. **Knowledge Base Creation**
- Loads web content from configured URLs
- Document preprocessing and vectorization
- Vector database initialization

### 3. **Query Processing**
- Content-aware routing
- Document retrieval and relevance grading
- Query optimization and transformation

### 4. **Answer Generation**
- Context-aware response generation
- Multi-source synthesis
- Quality validation and refinement

### 5. **Quality Control**
- Hallucination detection
- Answer relevance grading
- Iterative improvement loop

## 🚀 Usage Example

The system processes queries through the adaptive workflow:

```python
# example query
inputs = {"question": "What is the AlphaCodium paper about?"}
for output in app.stream(inputs):
    for key, value in output.items():
        print(f"Node '{key}':")
```

## 📊 Data Sources

### Default Knowledge Base
- LLM agents (Lilian Weng's blog)
- Prompt engineering techniques
- Adversarial attacks on LLMs

### Dynamic Sources
- Real-time web search results
- Context-relevant document retrieval

## 🔧 Configuration

### Required API Keys
- `TAVILY_API_KEY`: web search
- `NOMIC_API_KEY`: embedding service

### Local Model
- **Model**: Mistral (via Ollama)
- **Temperature**: 0 (deterministic responses)
- **Format**: JSON for structured outputs

## 💡 Key Innovations

1. **Adaptive routing**: dynamic source selection based on query semantics
2. **Multi-layer validation**: document relevance, answer quality, and hallucination checks
3. **Self-improving queries**: automatic query transformation for better results
4. **Hybrid architecture**: seamless integration of local and web sources

This adaptive RAG system takes an advanced approach to information retrieval and generation, ensuring high-quality, relevant responses through intelligent workflow management and continuous quality assessment.
RERANKING_PRINCIPLES.md
ADDED
@@ -0,0 +1,359 @@
# Vector Reranking Explained

## 🎯 What Is Reranking

Reranking is an advanced technique in retrieval-augmented generation (RAG) systems: after the initial vector retrieval, the candidate documents are re-ordered to improve the quality and relevance of the final retrieval results.

## 🔍 Why Rerank

### Limitations of Initial Retrieval

1. **Semantic distance bias**
   - Vector similarity may not fully capture semantic relevance
   - Relevant documents phrased differently can rank low

2. **Shallow context understanding**
   - Plain cosine similarity cannot grasp complex query intent
   - No deep modeling of the query-document interaction

3. **Diversity problems**
   - Initial retrieval may return near-duplicate documents
   - Results lack diversity and coverage

## 🧠 Core Principles

### 1. Two-Stage Retrieval Architecture

```
Query → Coarse stage (vector retrieval) → Fine stage (reranker) → Final results
              ↓                                  ↓
       recall candidates                 re-score and re-order
       (100-1000 docs)                   (keep the top k)
```
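
As a concrete reference, here is a runnable sketch of the two-stage pipeline using the `sentence-transformers` library. The checkpoints named here are common public models chosen for illustration, not ones this project ships:

```python
from sentence_transformers import SentenceTransformer, CrossEncoder, util

docs = [
    "Reranking reorders retrieved passages by fine-grained relevance.",
    "Bananas are rich in potassium.",
    "Cross-encoders score each query-document pair jointly.",
]
query = "How does reranking improve retrieval quality?"

# stage 1: coarse recall with a bi-encoder
bi_encoder = SentenceTransformer("all-MiniLM-L6-v2")
doc_emb = bi_encoder.encode(docs, convert_to_tensor=True)
query_emb = bi_encoder.encode(query, convert_to_tensor=True)
hits = util.semantic_search(query_emb, doc_emb, top_k=3)[0]

# stage 2: fine reranking of the recalled candidates with a cross-encoder
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
pairs = [(query, docs[h["corpus_id"]]) for h in hits]
scores = cross_encoder.predict(pairs)

for (_, doc), score in sorted(zip(pairs, scores), key=lambda x: x[1], reverse=True):
    print(f"{score:.3f}  {doc}")
```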

### 2. Reranking Model Types

#### A. Cross-Encoder
```python
# schematic: `model` is a placeholder relevance scorer
def cross_encoder_rerank(query, documents):
    scores = []
    for doc in documents:
        # encode query and document together
        input_text = f"[CLS] {query} [SEP] {doc} [SEP]"
        score = model(input_text)  # outputs a relevance score directly
        scores.append(score)
    return sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)
```

#### B. Bi-Encoder Reranking
```python
# schematic: query_encoder / doc_encoder / complex_similarity are placeholders
def bi_encoder_rerank(query, documents):
    query_embedding = query_encoder(query)
    doc_embeddings = [doc_encoder(doc) for doc in documents]

    # use a richer similarity computation than plain cosine
    scores = []
    for doc_emb in doc_embeddings:
        score = complex_similarity(query_embedding, doc_emb)
        scores.append(score)
    return sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)
```

## 🔬 Reranking Algorithms

### 1. Machine-Learning Reranking

#### Learning to Rank (LTR)
```python
class LearnToRankReranker:
    def __init__(self):
        self.model = None  # XGBoost, LambdaMART, etc.

    def extract_features(self, query, document):
        """Extract query-document features (helper functions are schematic)."""
        features = [
            # text-matching features
            jaccard_similarity(query, document),
            tf_idf_score(query, document),
            bm25_score(query, document),

            # semantic features
            cosine_similarity(query_emb, doc_emb),
            bert_score(query, document),

            # document features
            document_length(document),
            document_quality_score(document),

            # query features
            query_complexity(query),
            query_type_classification(query)
        ]
        return features

    def rerank(self, query, documents):
        features_matrix = []
        for doc in documents:
            features = self.extract_features(query, doc)
            features_matrix.append(features)

        scores = self.model.predict(features_matrix)
        return sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)
```

### 2. Deep-Learning Reranking

#### Transformer Reranker
```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

class TransformerReranker:
    def __init__(self, model_name="cross-encoder/ms-marco-MiniLM-L-6-v2"):
        # the model should be a relevance cross-encoder with a
        # sequence-classification head
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)

    def rerank(self, query, documents, top_k=5):
        scores = []

        for doc in documents:
            # build the paired input
            inputs = self.tokenizer(
                query, doc,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt"
            )

            # get the relevance score (this checkpoint emits a single logit)
            with torch.no_grad():
                outputs = self.model(**inputs)
                score = outputs.logits.view(-1)[0]
                scores.append(score.item())

        # re-order
        ranked_results = sorted(
            zip(documents, scores),
            key=lambda x: x[1],
            reverse=True
        )

        return ranked_results[:top_k]
```

### 3. Multi-Strategy Fusion Reranking

```python
class MultiStrategyReranker:
    def __init__(self):
        self.semantic_weight = 0.4
        self.lexical_weight = 0.3
        self.diversity_weight = 0.2
        self.freshness_weight = 0.1

    def rerank(self, query, documents):
        # 1. semantic relevance scores
        semantic_scores = self.compute_semantic_scores(query, documents)

        # 2. lexical match scores
        lexical_scores = self.compute_lexical_scores(query, documents)

        # 3. diversity scores
        diversity_scores = self.compute_diversity_scores(documents)

        # 4. freshness scores
        freshness_scores = self.compute_freshness_scores(documents)

        # 5. weighted fusion
        final_scores = []
        for i in range(len(documents)):
            score = (
                self.semantic_weight * semantic_scores[i] +
                self.lexical_weight * lexical_scores[i] +
                self.diversity_weight * diversity_scores[i] +
                self.freshness_weight * freshness_scores[i]
            )
            final_scores.append(score)

        return sorted(zip(documents, final_scores), key=lambda x: x[1], reverse=True)
```
|
| 180 |
+
## 🎛️ 重排特征工程
|
| 181 |
+
|
| 182 |
+
### 1. 文本匹配特征
|
| 183 |
+
```python
|
| 184 |
+
def extract_text_features(query, document):
|
| 185 |
+
return {
|
| 186 |
+
# 精确匹配
|
| 187 |
+
'exact_match_ratio': exact_match_count(query, document) / len(query.split()),
|
| 188 |
+
|
| 189 |
+
# 模糊匹配
|
| 190 |
+
'fuzzy_match_score': fuzz.ratio(query, document) / 100,
|
| 191 |
+
|
| 192 |
+
# N-gram重叠
|
| 193 |
+
'bigram_overlap': ngram_overlap(query, document, n=2),
|
| 194 |
+
'trigram_overlap': ngram_overlap(query, document, n=3),
|
| 195 |
+
|
| 196 |
+
# TF-IDF相似度
|
| 197 |
+
'tfidf_similarity': tfidf_cosine_similarity(query, document),
|
| 198 |
+
|
| 199 |
+
# BM25分数
|
| 200 |
+
'bm25_score': compute_bm25(query, document)
|
| 201 |
+
}
|
| 202 |
+
```
|
| 203 |
+
|
| 204 |
+
### 2. 语义特征
|
| 205 |
+
```python
|
| 206 |
+
def extract_semantic_features(query, document, embeddings):
|
| 207 |
+
query_emb = embeddings['query']
|
| 208 |
+
doc_emb = embeddings['document']
|
| 209 |
+
|
| 210 |
+
return {
|
| 211 |
+
# 余弦相似度
|
| 212 |
+
'cosine_similarity': cosine_sim(query_emb, doc_emb),
|
| 213 |
+
|
| 214 |
+
# 欧几里得距离
|
| 215 |
+
'euclidean_distance': euclidean_distance(query_emb, doc_emb),
|
| 216 |
+
|
| 217 |
+
# 曼哈顿距离
|
| 218 |
+
'manhattan_distance': manhattan_distance(query_emb, doc_emb),
|
| 219 |
+
|
| 220 |
+
# BERT分数
|
| 221 |
+
'bert_score': bert_score_f1(query, document),
|
| 222 |
+
|
| 223 |
+
# 语义角度
|
| 224 |
+
'semantic_angle': semantic_angle(query_emb, doc_emb)
|
| 225 |
+
}
|
| 226 |
+
```
|
| 227 |
+
|
| 228 |
+
### 3. 文档质量特征
|
| 229 |
+
```python
|
| 230 |
+
def extract_quality_features(document):
|
| 231 |
+
return {
|
| 232 |
+
# 长度特征
|
| 233 |
+
'doc_length': len(document.split()),
|
| 234 |
+
'sentence_count': len(sent_tokenize(document)),
|
| 235 |
+
|
| 236 |
+
# 可读性特征
|
| 237 |
+
'readability_score': textstat.flesch_reading_ease(document),
|
| 238 |
+
'complexity_score': textstat.flesch_kincaid_grade(document),
|
| 239 |
+
|
| 240 |
+
# 信息密度
|
| 241 |
+
'unique_word_ratio': len(set(document.split())) / len(document.split()),
|
| 242 |
+
'stopword_ratio': stopword_count(document) / len(document.split()),
|
| 243 |
+
|
| 244 |
+
# 结构特征
|
| 245 |
+
'has_headers': bool(re.search(r'^#+\s', document, re.MULTILINE)),
|
| 246 |
+
'has_lists': bool(re.search(r'^\s*[-*+]\s', document, re.MULTILINE))
|
| 247 |
+
}
|
| 248 |
+
```
|
| 249 |
+
|
| 250 |
+
## 🚀 实际应用示例
|
| 251 |
+
|
| 252 |
+
### 集成到RAG系统中
|
| 253 |
+
```python
|
| 254 |
+
class AdaptiveRAGWithReranking:
|
| 255 |
+
def __init__(self):
|
| 256 |
+
self.initial_retriever = VectorRetriever()
|
| 257 |
+
self.reranker = TransformerReranker()
|
| 258 |
+
self.generator = LanguageModel()
|
| 259 |
+
|
| 260 |
+
def query(self, question, top_k=5, rerank_candidates=20):
|
| 261 |
+
# 1. 初始检索(获取更多候选)
|
| 262 |
+
initial_docs = self.initial_retriever.retrieve(
|
| 263 |
+
question,
|
| 264 |
+
top_k=rerank_candidates
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
# 2. 重排
|
| 268 |
+
reranked_docs = self.reranker.rerank(
|
| 269 |
+
question,
|
| 270 |
+
initial_docs,
|
| 271 |
+
top_k=top_k
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
# 3. 生成答案
|
| 275 |
+
context = "\n\n".join([doc[0] for doc in reranked_docs])
|
| 276 |
+
answer = self.generator.generate(question, context)
|
| 277 |
+
|
| 278 |
+
return {
|
| 279 |
+
'answer': answer,
|
| 280 |
+
'sources': reranked_docs,
|
| 281 |
+
'confidence': self.calculate_confidence(reranked_docs)
|
| 282 |
+
}
|
| 283 |
+
```
|
| 284 |
+
|
| 285 |
+
## 📊 性能评估指标
|
| 286 |
+
|
| 287 |
+
### 1. 排序质量指标
|
| 288 |
+
```python
|
| 289 |
+
def evaluate_reranking(original_ranking, reranked_results, ground_truth):
|
| 290 |
+
return {
|
| 291 |
+
# NDCG (Normalized Discounted Cumulative Gain)
|
| 292 |
+
'ndcg@5': ndcg_score(ground_truth, reranked_results, k=5),
|
| 293 |
+
'ndcg@10': ndcg_score(ground_truth, reranked_results, k=10),
|
| 294 |
+
|
| 295 |
+
# MAP (Mean Average Precision)
|
| 296 |
+
'map': mean_average_precision(ground_truth, reranked_results),
|
| 297 |
+
|
| 298 |
+
# MRR (Mean Reciprocal Rank)
|
| 299 |
+
'mrr': mean_reciprocal_rank(ground_truth, reranked_results),
|
| 300 |
+
|
| 301 |
+
# 排序改进度
|
| 302 |
+
'ranking_improvement': kendall_tau(original_ranking, reranked_results)
|
| 303 |
+
}
|
| 304 |
+
```
|
| 305 |
+
|
| 306 |
+
### 2. 端到端效果评估
|
| 307 |
+
```python
|
| 308 |
+
def evaluate_rag_with_reranking(test_questions, ground_truth_answers):
|
| 309 |
+
results = []
|
| 310 |
+
|
| 311 |
+
for question, gt_answer in zip(test_questions, ground_truth_answers):
|
| 312 |
+
# 无重排
|
| 313 |
+
original_answer = rag_without_rerank(question)
|
| 314 |
+
|
| 315 |
+
# 有重排
|
| 316 |
+
reranked_answer = rag_with_rerank(question)
|
| 317 |
+
|
| 318 |
+
results.append({
|
| 319 |
+
'question': question,
|
| 320 |
+
'original_score': evaluate_answer(original_answer, gt_answer),
|
| 321 |
+
'reranked_score': evaluate_answer(reranked_answer, gt_answer),
|
| 322 |
+
'improvement': evaluate_answer(reranked_answer, gt_answer) -
|
| 323 |
+
evaluate_answer(original_answer, gt_answer)
|
| 324 |
+
})
|
| 325 |
+
|
| 326 |
+
return results
|
| 327 |
+
```
|
| 328 |
+
|
| 329 |
+
## 💡 最佳实践
|
| 330 |
+
|
| 331 |
+
### 1. 重排策略选择
|
| 332 |
+
- **实时性要求高**: 使用轻量级规则或简单ML模型
|
| 333 |
+
- **精度要求高**: 使用深度学习重排模型
|
| 334 |
+
- **平衡性能**: 多策略融合 + 缓存优化
|
| 335 |
+
|
| 336 |
+
### 2. 特征选择原则
|
| 337 |
+
- **相关性特征**: 语义相似度、词汇匹配
|
| 338 |
+
- **质量特征**: 文档权威性、完整性
|
| 339 |
+
- **多样性特征**: 避免结果冗余
|
| 340 |
+
- **时效性特征**: 信息新鲜度
|
| 341 |
+
|
| 342 |
+
### 3. 系统优化
|
| 343 |
+
```python
|
| 344 |
+
class OptimizedReranker:
|
| 345 |
+
def __init__(self):
|
| 346 |
+
self.cache = LRUCache(maxsize=1000)
|
| 347 |
+
self.batch_size = 32
|
| 348 |
+
|
| 349 |
+
@lru_cache(maxsize=1000)
|
| 350 |
+
def cached_rerank(self, query_hash, doc_hashes):
|
| 351 |
+
"""缓存重排结果"""
|
| 352 |
+
pass
|
| 353 |
+
|
| 354 |
+
def batch_rerank(self, queries, documents):
|
| 355 |
+
"""批量重排优化"""
|
| 356 |
+
pass
|
| 357 |
+
```
|
| 358 |
+
|
| 359 |
+
重排向量是提升RAG系统检索精度的关键技术,通过多层次的相关性评估和智能排序,显著提高了最终答案的质量和准确性。
|
__pycache__/config.cpython-310.pyc
ADDED
Binary file (2.34 kB).

__pycache__/config.cpython-313.pyc
ADDED
Binary file (2.95 kB).

__pycache__/document_processor.cpython-310.pyc
ADDED
Binary file (7.21 kB).

__pycache__/document_processor.cpython-313.pyc
ADDED
Binary file (4.56 kB).

__pycache__/entity_extractor.cpython-310.pyc
ADDED
Binary file (6.49 kB).

__pycache__/graph_indexer.cpython-310.pyc
ADDED
Binary file (4.54 kB).

__pycache__/graph_retriever.cpython-310.pyc
ADDED
Binary file (7.51 kB).

__pycache__/knowledge_graph.cpython-310.pyc
ADDED
Binary file (10.9 kB).

__pycache__/main.cpython-310.pyc
ADDED
Binary file (4.39 kB).

__pycache__/main.cpython-313.pyc
ADDED
Binary file (6.55 kB).

__pycache__/reranker.cpython-310.pyc
ADDED
Binary file (11.2 kB).

__pycache__/routers_and_graders.cpython-310.pyc
ADDED
Binary file (6.07 kB).

__pycache__/routers_and_graders.cpython-313.pyc
ADDED
Binary file (8.04 kB).

__pycache__/workflow_nodes.cpython-310.pyc
ADDED
Binary file (7.32 kB).

__pycache__/workflow_nodes.cpython-313.pyc
ADDED
Binary file (8.54 kB).
colab_gpu_demo.ipynb
ADDED
@@ -0,0 +1,588 @@
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# 🚀 GraphRAG GPU检测与测试 - Google Colab版本\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"本Notebook用于在Google Colab上检测GPU可用性并测试GraphRAG系统的性能。\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"## 📋 使用步骤\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"1. **启用GPU**: 运行时 → 更改运行时类型 → 硬件加速器 → GPU (T4)\n",
|
| 14 |
+
"2. **运行所有单元格**: 依次执行下面的代码\n",
|
| 15 |
+
"3. **查看结果**: 检查GPU加速效果\n",
|
| 16 |
+
"\n",
|
| 17 |
+
"---"
|
| 18 |
+
]
|
| 19 |
+
},
|
| 20 |
+
{
|
| 21 |
+
"cell_type": "markdown",
|
| 22 |
+
"metadata": {},
|
| 23 |
+
"source": [
|
| 24 |
+
"## 1️⃣ GPU环境检测"
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"cell_type": "code",
|
| 29 |
+
"execution_count": null,
|
| 30 |
+
"metadata": {},
|
| 31 |
+
"outputs": [],
|
| 32 |
+
"source": [
|
| 33 |
+
"# 检测GPU可用性\n",
|
| 34 |
+
"import torch\n",
|
| 35 |
+
"import subprocess\n",
|
| 36 |
+
"import sys\n",
|
| 37 |
+
"\n",
|
| 38 |
+
"print(\"=\"*60)\n",
|
| 39 |
+
"print(\"🔍 GPU环境检测\")\n",
|
| 40 |
+
"print(\"=\"*60)\n",
|
| 41 |
+
"\n",
|
| 42 |
+
"# PyTorch GPU检测\n",
|
| 43 |
+
"cuda_available = torch.cuda.is_available()\n",
|
| 44 |
+
"print(f\"\\n✅ CUDA可用: {cuda_available}\")\n",
|
| 45 |
+
"\n",
|
| 46 |
+
"if cuda_available:\n",
|
| 47 |
+
" print(f\" GPU数量: {torch.cuda.device_count()}\")\n",
|
| 48 |
+
" print(f\" 当前GPU: {torch.cuda.current_device()}\")\n",
|
| 49 |
+
" print(f\" GPU名称: {torch.cuda.get_device_name(0)}\")\n",
|
| 50 |
+
" print(f\" CUDA版本: {torch.version.cuda}\")\n",
|
| 51 |
+
" \n",
|
| 52 |
+
" # 显存信息\n",
|
| 53 |
+
" total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)\n",
|
| 54 |
+
" print(f\" 总显存: {total_memory:.2f} GB\")\n",
|
| 55 |
+
" \n",
|
| 56 |
+
" # nvidia-smi信息\n",
|
| 57 |
+
" print(\"\\n📊 nvidia-smi 输出:\")\n",
|
| 58 |
+
" print(\"-\"*60)\n",
|
| 59 |
+
" !nvidia-smi\n",
|
| 60 |
+
"else:\n",
|
| 61 |
+
" print(\"\\n⚠️ 警告: 未检测到GPU\")\n",
|
| 62 |
+
" print(\" 请检查: 运行时 → 更改运行时类型 → 硬件加速器 → GPU\")\n",
|
| 63 |
+
"\n",
|
| 64 |
+
"print(\"\\n\" + \"=\"*60)"
|
| 65 |
+
]
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"cell_type": "markdown",
|
| 69 |
+
"metadata": {},
|
| 70 |
+
"source": [
|
| 71 |
+
"## 2️⃣ GPU性能基准测试"
|
| 72 |
+
]
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"cell_type": "code",
|
| 76 |
+
"execution_count": null,
|
| 77 |
+
"metadata": {},
|
| 78 |
+
"outputs": [],
|
| 79 |
+
"source": [
|
| 80 |
+
"# GPU vs CPU 性能对比\n",
|
| 81 |
+
"import time\n",
|
| 82 |
+
"import numpy as np\n",
|
| 83 |
+
"\n",
|
| 84 |
+
"print(\"=\"*60)\n",
|
| 85 |
+
"print(\"⚡ GPU vs CPU 矩阵运算性能测试\")\n",
|
| 86 |
+
"print(\"=\"*60)\n",
"\n",
"# Test parameters\n",
"matrix_size = 5000\n",
"\n",
"# CPU test\n",
"print(f\"\\n🔵 CPU test (matrix size: {matrix_size}x{matrix_size})\")\n",
"a_cpu = torch.randn(matrix_size, matrix_size)\n",
"b_cpu = torch.randn(matrix_size, matrix_size)\n",
"\n",
"start = time.time()\n",
"c_cpu = torch.mm(a_cpu, b_cpu)\n",
"cpu_time = time.time() - start\n",
"print(f\" CPU time: {cpu_time:.2f} s\")\n",
"\n",
"# GPU test\n",
"if cuda_available:\n",
"    print(f\"\\n🟢 GPU test (matrix size: {matrix_size}x{matrix_size})\")\n",
"    a_gpu = torch.randn(matrix_size, matrix_size).cuda()\n",
"    b_gpu = torch.randn(matrix_size, matrix_size).cuda()\n",
"\n",
"    # Warm up the GPU\n",
"    _ = torch.mm(a_gpu, b_gpu)\n",
"    torch.cuda.synchronize()\n",
"\n",
"    start = time.time()\n",
"    c_gpu = torch.mm(a_gpu, b_gpu)\n",
"    torch.cuda.synchronize()\n",
"    gpu_time = time.time() - start\n",
"    print(f\" GPU time: {gpu_time:.2f} s\")\n",
"\n",
"    speedup = cpu_time / gpu_time\n",
"    print(f\"\\n🚀 Speedup: {speedup:.2f}x\")\n",
"    print(f\" GPU is {speedup:.1f}x faster than CPU!\")\n",
"else:\n",
"    print(\"\\n⚠️ Skipping GPU test (GPU unavailable)\")\n",
"\n",
"print(\"\\n\" + \"=\"*60)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3️⃣ Install GraphRAG dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Clone the project (if needed)\n",
"import os\n",
"\n",
"print(\"📦 Installing GraphRAG dependencies...\\n\")\n",
"\n",
"# Install the core dependencies\n",
"!pip install -q langchain langchain-community langchain-core langgraph\n",
"!pip install -q chromadb sentence-transformers transformers\n",
"!pip install -q tiktoken beautifulsoup4 requests\n",
"!pip install -q tavily-python python-dotenv\n",
"!pip install -q networkx python-louvain\n",
"!pip install -q torch --index-url https://download.pytorch.org/whl/cu118\n",
"\n",
"print(\"\\n✅ Dependencies installed!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4️⃣ Upload the project files\n",
"\n",
"**Option A**: clone from GitHub\n",
"```python\n",
"!git clone https://github.com/your-repo/adaptive_RAG.git\n",
"%cd adaptive_RAG\n",
"```\n",
"\n",
"**Option B**: upload files to Colab manually\n",
"- Use the file browser on the left to upload these core files:\n",
"  - `config.py`\n",
"  - `entity_extractor.py`\n",
"  - `knowledge_graph.py`\n",
"  - `graph_indexer.py`\n",
"  - `graph_retriever.py`\n",
"  - `.env` (contains the API keys)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create the required directories\n",
"!mkdir -p data\n",
"\n",
"# If you chose Option A, run the commands below\n",
"# !git clone YOUR_REPO_URL\n",
"# %cd adaptive_RAG\n",
"\n",
"print(\"✅ Directories ready\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5️⃣ Configure the API keys"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Set the API keys (replace with your real keys)\n",
"import os\n",
"from getpass import getpass\n",
"\n",
"print(\"🔑 Configuring API keys\\n\")\n",
"\n",
"# Option 1: set directly (insecure, for testing only)\n",
"# os.environ['TAVILY_API_KEY'] = 'your_tavily_api_key_here'\n",
"\n",
"# Option 2: secure interactive input\n",
"if 'TAVILY_API_KEY' not in os.environ:\n",
"    os.environ['TAVILY_API_KEY'] = getpass('Enter TAVILY_API_KEY: ')\n",
"    print(\"✅ TAVILY_API_KEY set\")\n",
"else:\n",
"    print(\"✅ TAVILY_API_KEY already set\")\n",
"\n",
"print(\"\\nNote: on Colab, GraphRAG uses HuggingFace embeddings, so NOMIC_API_KEY is not needed\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6️⃣ Simplified GraphRAG test code"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Simplified GraphRAG core components\n",
"# For quick tests on Colab; no full project files needed\n",
"\n",
"from typing import List, Dict\n",
"import networkx as nx\n",
"from sentence_transformers import SentenceTransformer\n",
"import torch\n",
"\n",
"class SimpleGraphRAG:\n",
"    \"\"\"Simplified GraphRAG for GPU performance testing\"\"\"\n",
"\n",
"    def __init__(self, use_gpu=True):\n",
"        print(\"🚀 Initializing SimpleGraphRAG...\")\n",
"\n",
"        # Detect the device\n",
"        self.device = 'cuda' if use_gpu and torch.cuda.is_available() else 'cpu'\n",
"        print(f\" Device: {self.device.upper()}\")\n",
"\n",
"        # Load the embedding model\n",
"        print(\" Loading the embedding model...\")\n",
"        self.embedder = SentenceTransformer(\n",
"            'sentence-transformers/all-MiniLM-L6-v2',\n",
"            device=self.device\n",
"        )\n",
"\n",
"        # Knowledge graph\n",
"        self.graph = nx.Graph()\n",
"        self.entities = {}\n",
"\n",
"        print(\"✅ Initialization complete!\")\n",
"\n",
"    def add_sample_data(self):\n",
"        \"\"\"Add sample data\"\"\"\n",
"        print(\"\\n📊 Adding sample data...\")\n",
"\n",
"        # Sample entities\n",
"        entities = [\n",
"            {\"name\": \"LLM\", \"type\": \"CONCEPT\", \"desc\": \"Large language model\"},\n",
"            {\"name\": \"GPT\", \"type\": \"TECHNOLOGY\", \"desc\": \"Generative pre-trained transformer\"},\n",
"            {\"name\": \"Transformer\", \"type\": \"CONCEPT\", \"desc\": \"Attention-based architecture\"},\n",
"            {\"name\": \"OpenAI\", \"type\": \"ORGANIZATION\", \"desc\": \"AI research company\"},\n",
"            {\"name\": \"Attention\", \"type\": \"CONCEPT\", \"desc\": \"Attention mechanism\"},\n",
"        ]\n",
"\n",
"        for entity in entities:\n",
"            self.graph.add_node(\n",
"                entity[\"name\"],\n",
"                type=entity[\"type\"],\n",
"                description=entity[\"desc\"]\n",
"            )\n",
"            self.entities[entity[\"name\"]] = entity\n",
"\n",
"        # Sample relations\n",
"        relations = [\n",
"            (\"GPT\", \"LLM\", \"IS_A\"),\n",
"            (\"GPT\", \"Transformer\", \"USES\"),\n",
"            (\"Transformer\", \"Attention\", \"CONTAINS\"),\n",
"            (\"OpenAI\", \"GPT\", \"DEVELOPS\"),\n",
"        ]\n",
"\n",
"        for source, target, rel_type in relations:\n",
"            self.graph.add_edge(source, target, relation=rel_type)\n",
"\n",
"        print(f\" ✅ Added {len(entities)} entities\")\n",
"        print(f\" ✅ Added {len(relations)} relations\")\n",
"\n",
"    def test_gpu_embedding(self, texts: List[str]):\n",
"        \"\"\"Benchmark embedding performance\"\"\"\n",
"        print(f\"\\n⚡ Benchmarking embeddings ({len(texts)} texts)...\")\n",
"\n",
"        import time\n",
"\n",
"        start = time.time()\n",
"        embeddings = self.embedder.encode(\n",
"            texts,\n",
"            show_progress_bar=True,\n",
"            batch_size=32\n",
"        )\n",
"        elapsed = time.time() - start\n",
"\n",
"        print(f\" ✅ Done! Elapsed: {elapsed:.2f}s\")\n",
"        print(f\" 📊 Embedding shape: {embeddings.shape}\")\n",
"        print(f\" 🚀 Speed: {len(texts)/elapsed:.1f} texts/sec\")\n",
"\n",
"        return embeddings\n",
"\n",
"    def query(self, question: str):\n",
"        \"\"\"Simple query\"\"\"\n",
"        print(f\"\\n🔍 Query: {question}\")\n",
"\n",
"        # Simple keyword matching\n",
"        results = []\n",
"        for entity_name in self.entities:\n",
"            if entity_name.lower() in question.lower():\n",
"                neighbors = list(self.graph.neighbors(entity_name))\n",
"                results.append({\n",
"                    \"entity\": entity_name,\n",
"                    \"info\": self.entities[entity_name],\n",
"                    \"neighbors\": neighbors\n",
"                })\n",
"\n",
"        print(f\"\\n📋 Found {len(results)} relevant entities:\")\n",
"        for r in results:\n",
"            print(f\" • {r['entity']} ({r['info']['type']})\")\n",
"            print(f\"   Description: {r['info']['desc']}\")\n",
"            print(f\"   Linked to: {', '.join(r['neighbors'])}\")\n",
"\n",
"        return results\n",
"\n",
"print(\"✅ SimpleGraphRAG class defined\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7️⃣ Run the GPU performance test"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize GraphRAG (GPU version)\n",
"print(\"=\"*60)\n",
"print(\"🎯 GraphRAG GPU performance test\")\n",
"print(\"=\"*60)\n",
"\n",
"graph_rag = SimpleGraphRAG(use_gpu=True)\n",
"\n",
"# Add sample data\n",
"graph_rag.add_sample_data()\n",
"\n",
"# Prepare the test texts\n",
"test_texts = [\n",
"    \"Large Language Models are transforming AI\",\n",
"    \"GPT uses Transformer architecture\",\n",
"    \"Attention mechanism is key to modern NLP\",\n",
"    \"OpenAI develops cutting-edge AI models\",\n",
"] * 25  # 100 texts\n",
"\n",
"print(f\"\\nPrepared {len(test_texts)} test texts\")\n",
"\n",
"# GPU embedding benchmark\n",
"embeddings = graph_rag.test_gpu_embedding(test_texts)\n",
"\n",
"# Test queries\n",
"graph_rag.query(\"What is GPT?\")\n",
"graph_rag.query(\"Tell me about Transformer\")\n",
"\n",
"print(\"\\n\" + \"=\"*60)\n",
"print(\"✅ GPU performance test complete!\")\n",
"print(\"=\"*60)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 8️⃣ CPU vs GPU performance comparison"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# CPU vs GPU embedding performance comparison\n",
"import time\n",
"\n",
"print(\"=\"*60)\n",
"print(\"📊 CPU vs GPU embedding performance\")\n",
"print(\"=\"*60)\n",
"\n",
"# Prepare a larger test set\n",
"large_test_texts = test_texts * 10  # 1000 texts\n",
"print(f\"\\nTest data: {len(large_test_texts)} texts\\n\")\n",
"\n",
"# CPU test\n",
"print(\"🔵 CPU test...\")\n",
"graph_rag_cpu = SimpleGraphRAG(use_gpu=False)\n",
"start = time.time()\n",
"embeddings_cpu = graph_rag_cpu.embedder.encode(\n",
"    large_test_texts,\n",
"    show_progress_bar=False,\n",
"    batch_size=32\n",
")\n",
"cpu_time = time.time() - start\n",
"print(f\" CPU time: {cpu_time:.2f}s\")\n",
"print(f\" Speed: {len(large_test_texts)/cpu_time:.1f} texts/sec\")\n",
"\n",
"# GPU test\n",
"if cuda_available:\n",
"    print(\"\\n🟢 GPU test...\")\n",
"    graph_rag_gpu = SimpleGraphRAG(use_gpu=True)\n",
"    start = time.time()\n",
"    embeddings_gpu = graph_rag_gpu.embedder.encode(\n",
"        large_test_texts,\n",
"        show_progress_bar=False,\n",
"        batch_size=32\n",
"    )\n",
"    gpu_time = time.time() - start\n",
"    print(f\" GPU time: {gpu_time:.2f}s\")\n",
"    print(f\" Speed: {len(large_test_texts)/gpu_time:.1f} texts/sec\")\n",
"\n",
"    # Use a distinct name so the matrix-test `speedup` from step 2 survives\n",
"    emb_speedup = cpu_time / gpu_time\n",
"    print(f\"\\n🚀 Speedup: {emb_speedup:.2f}x\")\n",
"    print(f\" GPU is {emb_speedup:.1f}x faster than CPU!\")\n",
"\n",
"    # Time saved\n",
"    time_saved = cpu_time - gpu_time\n",
"    print(f\" ⏱️ Time saved: {time_saved:.2f}s\")\n",
"else:\n",
"    print(\"\\n⚠️ GPU unavailable, skipping the GPU test\")\n",
"\n",
"print(\"\\n\" + \"=\"*60)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 9️⃣ GPU memory monitoring"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Monitor GPU memory usage\n",
"if cuda_available:\n",
"    print(\"=\"*60)\n",
"    print(\"💾 GPU memory usage\")\n",
"    print(\"=\"*60)\n",
"\n",
"    allocated = torch.cuda.memory_allocated(0) / (1024**3)\n",
"    reserved = torch.cuda.memory_reserved(0) / (1024**3)\n",
"    total = torch.cuda.get_device_properties(0).total_memory / (1024**3)\n",
"\n",
"    print(f\"\\nAllocated: {allocated:.2f} GB\")\n",
"    print(f\"Reserved: {reserved:.2f} GB\")\n",
"    print(f\"Total: {total:.2f} GB\")\n",
"    print(f\"Utilization: {(allocated/total)*100:.1f}%\")\n",
"\n",
"    print(\"\\nDetails:\")\n",
"    print(torch.cuda.memory_summary(0, abbreviated=True))\n",
"\n",
"    print(\"\\n\" + \"=\"*60)\n",
"else:\n",
"    print(\"⚠️ GPU unavailable\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🔟 Performance summary report"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Generate the performance report\n",
"print(\"=\"*60)\n",
"print(\"📈 GraphRAG GPU performance test report\")\n",
"print(\"=\"*60)\n",
"\n",
"print(\"\\n🖥️ Hardware:\")\n",
"if cuda_available:\n",
"    print(f\" GPU model: {torch.cuda.get_device_name(0)}\")\n",
"    print(f\" Memory: {torch.cuda.get_device_properties(0).total_memory / (1024**3):.2f} GB\")\n",
"    print(f\" CUDA version: {torch.version.cuda}\")\n",
"else:\n",
"    print(\" ⚠️ GPU unavailable\")\n",
"\n",
"print(f\"\\n PyTorch version: {torch.__version__}\")\n",
"print(f\" Python version: {sys.version.split()[0]}\")\n",
"\n",
"print(\"\\n⚡ Benchmark results:\")\n",
"print(f\" Matrix multiply speedup: ~{f'{speedup:.1f}' if cuda_available else 'N/A'}x\")\n",
"print(f\" Text embedding speedup: ~{f'{emb_speedup:.1f}' if cuda_available else 'N/A'}x\")\n",
"\n",
"print(\"\\n💡 Recommendations:\")\n",
"if cuda_available:\n",
"    print(\" ✅ The GPU is working well! Consider running the full GraphRAG index build on Colab\")\n",
"    print(\" ✅ Index build time should drop substantially\")\n",
"    print(\" ✅ Larger document collections become feasible\")\n",
"else:\n",
"    print(\" ⚠️ Enable the GPU for best performance\")\n",
"    print(\" ⚠️ Path: Runtime → Change runtime type → GPU\")\n",
"\n",
"print(\"\\n\" + \"=\"*60)\n",
"print(\"✅ Test complete!\")\n",
"print(\"=\"*60)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"\n",
"## 📚 Next steps\n",
"\n",
"If the GPU tests succeed, you can:\n",
"\n",
"1. **Upload the full project**: upload the entire adaptive_RAG project to Colab\n",
"2. **Run GraphRAG indexing**: build the knowledge graph with GPU acceleration\n",
"3. **Save the results**: download the built graph to your machine\n",
"\n",
"### Command to run the full GraphRAG:\n",
"\n",
"```python\n",
"# Run after uploading the project\n",
"!python main_graphrag.py\n",
"```\n",
"\n",
"### Expected speedups:\n",
"\n",
"- Entity extraction: faster with GPU-backed LLM inference\n",
"- Text embedding: **5-10x speedup**\n",
"- Vector similarity computation: **10-20x speedup**\n",
"- Overall index build time: **3-5x speedup**\n",
"\n",
"---"
]
}
],
"metadata": {
"accelerator": "GPU",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
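A note on the notebook's memory report: by the time cell 9 runs, the CPU and GPU embedding models from the comparison cell are still resident, so the numbers reflect the whole session rather than a single benchmark. A small helper along these lines (a sketch, not part of the project files) can be run between cells to start each measurement clean:

```python
import gc
import torch

def reset_gpu_memory() -> None:
    """Drop dead Python references, flush PyTorch's CUDA caching allocator,
    and reset the peak-memory counters before the next benchmark cell."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(0)
        allocated = torch.cuda.memory_allocated(0) / (1024**3)
        print(f"Allocated after reset: {allocated:.2f} GB")
```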
colab_gpu_test.py
ADDED
@@ -0,0 +1,269 @@
#!/usr/bin/env python3
"""
Google Colab GPU detection and GraphRAG performance test script
Run directly in Colab: python colab_gpu_test.py
"""

import sys
import time
import torch
import numpy as np
from typing import List, Dict

def print_section(title: str):
    """Print a section header"""
    print("\n" + "="*60)
    print(f"{title}")
    print("="*60 + "\n")


def test_gpu_availability():
    """Check GPU availability"""
    print_section("🔍 GPU environment check")

    cuda_available = torch.cuda.is_available()
    print(f"✅ CUDA available: {cuda_available}")

    if cuda_available:
        print(f" GPU count: {torch.cuda.device_count()}")
        print(f" Current GPU: {torch.cuda.current_device()}")
        print(f" GPU name: {torch.cuda.get_device_name(0)}")
        print(f" CUDA version: {torch.version.cuda}")

        total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        print(f" Total memory: {total_memory:.2f} GB")

        return True
    else:
        print("\n⚠️ Warning: no GPU detected")
        print(" Enable the GPU in Colab: Runtime → Change runtime type → GPU")
        return False


def benchmark_matrix_multiplication(matrix_size=5000):
    """GPU vs CPU matrix multiplication benchmark"""
    print_section("⚡ GPU vs CPU matrix multiplication benchmark")

    print(f"Matrix size: {matrix_size}x{matrix_size}\n")

    # CPU test
    print("🔵 CPU test...")
    a_cpu = torch.randn(matrix_size, matrix_size)
    b_cpu = torch.randn(matrix_size, matrix_size)

    start = time.time()
    c_cpu = torch.mm(a_cpu, b_cpu)
    cpu_time = time.time() - start
    print(f" CPU time: {cpu_time:.2f} s")

    # GPU test
    if torch.cuda.is_available():
        print("\n🟢 GPU test...")
        a_gpu = torch.randn(matrix_size, matrix_size).cuda()
        b_gpu = torch.randn(matrix_size, matrix_size).cuda()

        # Warm up the GPU
        _ = torch.mm(a_gpu, b_gpu)
        torch.cuda.synchronize()

        start = time.time()
        c_gpu = torch.mm(a_gpu, b_gpu)
        torch.cuda.synchronize()
        gpu_time = time.time() - start
        print(f" GPU time: {gpu_time:.2f} s")

        speedup = cpu_time / gpu_time
        print(f"\n🚀 Speedup: {speedup:.2f}x")
        print(f" GPU is {speedup:.1f}x faster than CPU!")

        return speedup
    else:
        print("\n⚠️ Skipping GPU test (GPU unavailable)")
        return 1.0


def test_text_embedding_performance():
    """Text embedding benchmark (requires sentence-transformers)"""
    print_section("📝 Text embedding benchmark")

    try:
        from sentence_transformers import SentenceTransformer

        # Prepare the test data
        test_texts = [
            "Large Language Models are transforming AI",
            "GraphRAG combines knowledge graphs with retrieval",
            "GPU acceleration significantly improves performance",
            "Natural language processing is advancing rapidly",
        ] * 250  # 1000 texts

        print(f"Test data: {len(test_texts)} texts\n")

        # CPU test
        print("🔵 CPU embedding test...")
        model_cpu = SentenceTransformer(
            'sentence-transformers/all-MiniLM-L6-v2',
            device='cpu'
        )
        start = time.time()
        embeddings_cpu = model_cpu.encode(test_texts, show_progress_bar=False, batch_size=32)
        cpu_time = time.time() - start
        print(f" CPU time: {cpu_time:.2f}s")
        print(f" Speed: {len(test_texts)/cpu_time:.1f} texts/sec")

        # GPU test
        if torch.cuda.is_available():
            print("\n🟢 GPU embedding test...")
            model_gpu = SentenceTransformer(
                'sentence-transformers/all-MiniLM-L6-v2',
                device='cuda'
            )
            start = time.time()
            embeddings_gpu = model_gpu.encode(test_texts, show_progress_bar=False, batch_size=32)
            gpu_time = time.time() - start
            print(f" GPU time: {gpu_time:.2f}s")
            print(f" Speed: {len(test_texts)/gpu_time:.1f} texts/sec")

            speedup = cpu_time / gpu_time
            print(f"\n🚀 Speedup: {speedup:.2f}x")
            print(f" Time saved: {cpu_time - gpu_time:.2f}s")

            return speedup
        else:
            print("\n⚠️ Skipping GPU test")
            return 1.0

    except ImportError:
        print("⚠️ sentence-transformers is not installed")
        print(" Install it with: pip install sentence-transformers")
        return None


def monitor_gpu_memory():
    """Report GPU memory usage"""
    if not torch.cuda.is_available():
        return

    print_section("💾 GPU memory usage")

    allocated = torch.cuda.memory_allocated(0) / (1024**3)
    reserved = torch.cuda.memory_reserved(0) / (1024**3)
    total = torch.cuda.get_device_properties(0).total_memory / (1024**3)

    print(f"Allocated: {allocated:.2f} GB")
    print(f"Reserved: {reserved:.2f} GB")
    print(f"Total: {total:.2f} GB")
    print(f"Utilization: {(allocated/total)*100:.1f}%")


def generate_performance_report(matrix_speedup, embedding_speedup):
    """Generate the performance report"""
    print_section("📈 Performance test summary")

    print("🖥️ Hardware:")
    if torch.cuda.is_available():
        print(f" GPU model: {torch.cuda.get_device_name(0)}")
        print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / (1024**3):.2f} GB")
        print(f" CUDA version: {torch.version.cuda}")
    else:
        print(" ⚠️ GPU unavailable")

    print(f"\n PyTorch version: {torch.__version__}")
    print(f" Python version: {sys.version.split()[0]}")

    print("\n⚡ Benchmark results:")
    print(f" Matrix multiply speedup: {matrix_speedup:.2f}x")
    if embedding_speedup:
        print(f" Text embedding speedup: {embedding_speedup:.2f}x")

    print("\n💡 Recommendations:")
    if torch.cuda.is_available():
        print(" ✅ The GPU is working well!")
        print(" ✅ Consider running the full GraphRAG index build on Colab")
        print(" ✅ Index build time should drop by roughly 3-5x")

        # Estimate the time saved
        if embedding_speedup and embedding_speedup > 1:
            print(f"\n⏱️ Estimated time savings:")
            print(f" 100 docs on CPU: ~15 minutes")
            print(f" 100 docs on GPU: ~{15/embedding_speedup:.1f} minutes")
            print(f" Saved: ~{15 - 15/embedding_speedup:.1f} minutes")
    else:
        print(" ⚠️ Enable the GPU for best performance")
        print(" ⚠️ In Colab: Runtime → Change runtime type → GPU")


def install_dependencies():
    """Install the required packages (Colab only)"""
    try:
        import google.colab
        is_colab = True
    except ImportError:
        is_colab = False

    if is_colab:
        print_section("📦 Installing dependencies")
        print("Colab environment detected, installing the required packages...\n")

        import subprocess
        packages = [
            'sentence-transformers',
            'networkx',
            'python-louvain',
        ]

        for package in packages:
            try:
                # Note: dash-to-underscore is only an approximation of the import
                # name (python-louvain actually imports as 'community'), so that
                # package is re-installed on every run; pip makes this a no-op.
                __import__(package.replace('-', '_'))
                print(f"✅ {package} already installed")
            except ImportError:
                print(f"📥 Installing {package}...")
                subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', package])
                print(f"✅ {package} installed")


def main():
    """Main entry point"""
    print("\n" + "="*60)
    print("🚀 Google Colab GPU detection and GraphRAG performance test")
    print("="*60)

    # Check whether we are running in Colab
    try:
        import google.colab
        print("\n✅ Environment: Google Colab")
    except ImportError:
        print("\n⚠️ Warning: Colab environment not detected")
        print(" This script is designed for Google Colab")

    # Install dependencies
    install_dependencies()

    # 1. GPU detection
    gpu_available = test_gpu_availability()

    # 2. Matrix multiplication benchmark
    matrix_speedup = benchmark_matrix_multiplication(matrix_size=5000)

    # 3. Text embedding benchmark
    embedding_speedup = test_text_embedding_performance()

    # 4. Memory monitoring
    if gpu_available:
        monitor_gpu_memory()

    # 5. Generate the report
    generate_performance_report(matrix_speedup, embedding_speedup)

    print("\n" + "="*60)
    print("✅ Tests complete!")
    print("="*60)

    print("\n📚 Next steps:")
    print(" 1. If the GPU tests succeed, upload the full adaptive_RAG project")
    print(" 2. Run main_graphrag.py to build the complete knowledge graph")
    print(" 3. Enjoy the 3-5x speedup from the GPU!")


if __name__ == "__main__":
    main()
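The one-shot timings in `benchmark_matrix_multiplication` above can swing noticeably between runs on a shared Colab GPU. If the reported speedup looks unstable, a repeated-median variant along these lines may help (a sketch; `timed_matmul` is an illustrative helper, not part of the project):

```python
import statistics
import time
import torch

def timed_matmul(device: str, size: int = 5000, repeats: int = 5) -> float:
    """Median wall-clock seconds for one size x size matmul on `device`."""
    a = torch.randn(size, size, device=device)
    b = torch.randn(size, size, device=device)
    samples = []
    for _ in range(repeats):
        if device == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()
        torch.mm(a, b)
        if device == "cuda":
            torch.cuda.synchronize()  # wait for the kernel to finish
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)

# Example: speedup = timed_matmul("cpu") / timed_matmul("cuda")
```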
colab_quick_test.py
ADDED
@@ -0,0 +1,278 @@
#!/usr/bin/env python3
"""
One-click GPU test script for Google Colab
Copy the contents of this file into a Colab cell and run it directly

Usage:
1. Create a new notebook in Colab
2. Enable the GPU (Runtime → Change runtime type → GPU)
3. Copy and run this script
"""

# ============================================================
# 🔧 Auto-install dependencies
# ============================================================
print("📦 Checking and installing dependencies...")
import subprocess
import sys

def install(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", package])

# Check the required packages
required_packages = {
    'torch': 'torch',
    'sentence_transformers': 'sentence-transformers',
    'networkx': 'networkx',
    'numpy': 'numpy'
}

for import_name, package_name in required_packages.items():
    try:
        __import__(import_name)
        print(f"✅ {package_name} already installed")
    except ImportError:
        print(f"📥 Installing {package_name}...")
        install(package_name)

print("\n" + "="*70)
print("🚀 Google Colab GPU performance test - GraphRAG acceleration check")
print("="*70)

# ============================================================
# 1️⃣ GPU detection
# ============================================================
import torch
import time

print("\n" + "="*70)
print("🔍 Step 1: GPU environment check")
print("="*70)

cuda_available = torch.cuda.is_available()
print(f"\n{'✅' if cuda_available else '❌'} CUDA available: {cuda_available}")

if cuda_available:
    print(f" 📊 GPU model: {torch.cuda.get_device_name(0)}")
    print(f" 💾 Memory size: {torch.cuda.get_device_properties(0).total_memory / (1024**3):.2f} GB")
    print(f" 🔢 CUDA version: {torch.version.cuda}")
    print(f" 📈 PyTorch version: {torch.__version__}")
else:
    print("\n⚠️ GPU is not enabled!")
    print(" Enable it as follows:")
    print(" 1. Click 'Runtime' in the top menu")
    print(" 2. Select 'Change runtime type'")
    print(" 3. Choose 'GPU' as the hardware accelerator")
    print(" 4. Click 'Save'")
    print(" 5. Re-run this cell")
    print("\n⚠️ The tests will continue, but GPU-specific tests are skipped")

# ============================================================
# 2️⃣ Matrix multiplication benchmark
# ============================================================
print("\n" + "="*70)
print("⚡ Step 2: matrix multiplication benchmark")
print("="*70)

matrix_size = 5000
print(f"\nTest setup: {matrix_size}x{matrix_size} matrix multiplication\n")

# CPU test
print("🔵 CPU benchmark...")
a_cpu = torch.randn(matrix_size, matrix_size)
b_cpu = torch.randn(matrix_size, matrix_size)

start = time.time()
c_cpu = torch.mm(a_cpu, b_cpu)
cpu_time = time.time() - start

print(f" ⏱️ CPU time: {cpu_time:.3f}s")

# GPU test
if cuda_available:
    print("\n🟢 GPU benchmark...")
    a_gpu = torch.randn(matrix_size, matrix_size).cuda()
    b_gpu = torch.randn(matrix_size, matrix_size).cuda()

    # Warm-up
    _ = torch.mm(a_gpu, b_gpu)
    torch.cuda.synchronize()

    start = time.time()
    c_gpu = torch.mm(a_gpu, b_gpu)
    torch.cuda.synchronize()
    gpu_time = time.time() - start

    print(f" ⏱️ GPU time: {gpu_time:.3f}s")

    speedup = cpu_time / gpu_time
    print(f"\n 🚀 Speedup: {speedup:.1f}x")
    print(f" 💡 GPU is {speedup:.1f}x faster than CPU!")

    matrix_speedup = speedup
else:
    print("\n⚠️ Skipping GPU test")
    matrix_speedup = 1.0

# ============================================================
# 3️⃣ Text embedding benchmark
# ============================================================
print("\n" + "="*70)
print("📝 Step 3: text embedding benchmark (GraphRAG core component)")
print("="*70)

try:
    from sentence_transformers import SentenceTransformer

    # Prepare the test data
    test_texts = [
        "GraphRAG combines knowledge graphs with retrieval augmented generation",
        "GPU acceleration significantly improves machine learning performance",
        "Large language models benefit from efficient embedding computation",
        "Knowledge graph construction requires entity and relation extraction",
    ] * 250  # 1000 texts

    print(f"\nTest setup: embedding {len(test_texts)} texts\n")

    # CPU embedding
    print("🔵 CPU embedding test...")
    model_cpu = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', device='cpu')

    start = time.time()
    embeddings_cpu = model_cpu.encode(test_texts, show_progress_bar=False, batch_size=32)
    cpu_emb_time = time.time() - start

    print(f" ⏱️ CPU time: {cpu_emb_time:.2f}s")
    print(f" 📊 Throughput: {len(test_texts)/cpu_emb_time:.1f} texts/sec")

    # GPU embedding
    if cuda_available:
        print("\n🟢 GPU embedding test...")
        model_gpu = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', device='cuda')

        start = time.time()
        embeddings_gpu = model_gpu.encode(test_texts, show_progress_bar=False, batch_size=32)
        gpu_emb_time = time.time() - start

        print(f" ⏱️ GPU time: {gpu_emb_time:.2f}s")
        print(f" 📊 Throughput: {len(test_texts)/gpu_emb_time:.1f} texts/sec")

        emb_speedup = cpu_emb_time / gpu_emb_time
        print(f"\n 🚀 Speedup: {emb_speedup:.1f}x")
        print(f" ⏱️ Time saved: {cpu_emb_time - gpu_emb_time:.2f}s")
    else:
        print("\n⚠️ Skipping GPU test")
        emb_speedup = 1.0

except ImportError:
    print("\n⚠️ sentence-transformers is not installed, skipping this test")
    emb_speedup = None

# ============================================================
# 4️⃣ GraphRAG scenario simulation
# ============================================================
print("\n" + "="*70)
print("🔍 Step 4: simulated GraphRAG workload")
print("="*70)

if cuda_available and emb_speedup:
    print("\nSimulating a GraphRAG index build...\n")

    # Assume an index build over 100 document chunks
    documents_count = 100

    # Entity extraction time (about 1 second per document)
    entity_extraction_time = documents_count * 1.0

    # Text embedding time (based on the measured throughput)
    # Assume each document yields ~10 entities, i.e. 1000 entities to embed
    entities_count = documents_count * 10

    cpu_total_time = entity_extraction_time + (entities_count / (len(test_texts)/cpu_emb_time))
    gpu_total_time = entity_extraction_time + (entities_count / (len(test_texts)/gpu_emb_time))

    print(f"📊 Scenario: GraphRAG index build over {documents_count} documents\n")
    print(f"🔵 Estimated CPU time:")
    print(f" - Entity extraction: {entity_extraction_time/60:.1f} min")
    print(f" - Vector embedding: {(entities_count / (len(test_texts)/cpu_emb_time))/60:.1f} min")
    print(f" - Total: {cpu_total_time/60:.1f} min")

    print(f"\n🟢 Estimated GPU time:")
    print(f" - Entity extraction: {entity_extraction_time/60:.1f} min (unchanged)")
    print(f" - Vector embedding: {(entities_count / (len(test_texts)/gpu_emb_time))/60:.1f} min")
    print(f" - Total: {gpu_total_time/60:.1f} min")

    total_speedup = cpu_total_time / gpu_total_time
    time_saved = (cpu_total_time - gpu_total_time) / 60

    print(f"\n🚀 Overall speedup: {total_speedup:.1f}x")
    print(f"⏱️ Time saved: {time_saved:.1f} min")

# ============================================================
# 5️⃣ GPU memory monitoring
# ============================================================
if cuda_available:
    print("\n" + "="*70)
    print("💾 Step 5: GPU memory usage")
    print("="*70)

    allocated = torch.cuda.memory_allocated(0) / (1024**3)
    reserved = torch.cuda.memory_reserved(0) / (1024**3)
    total = torch.cuda.get_device_properties(0).total_memory / (1024**3)

    print(f"\n Allocated: {allocated:.2f} GB")
    print(f" Reserved: {reserved:.2f} GB")
    print(f" Total: {total:.2f} GB")
    print(f" Utilization: {(allocated/total)*100:.1f}%")

# ============================================================
# 6️⃣ Performance summary
# ============================================================
print("\n" + "="*70)
print("📈 Final performance report")
print("="*70)

print("\n🖥️ Hardware:")
if cuda_available:
    print(f" GPU: {torch.cuda.get_device_name(0)}")
    print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / (1024**3):.1f} GB")
    print(f" CUDA: {torch.version.cuda}")
else:
    print(" ⚠️ GPU not enabled")

print("\n⚡ Benchmark results:")
print(f" Matrix multiply speedup: {matrix_speedup:.1f}x")
if emb_speedup:
    print(f" Text embedding speedup: {emb_speedup:.1f}x")
if cuda_available and emb_speedup:
    # total_speedup only exists when the embedding test actually ran
    print(f" GraphRAG overall speedup: {total_speedup:.1f}x")

print("\n💡 Conclusions and recommendations:")
if cuda_available:
    print(" ✅ GPU performance test succeeded!")
    print(" ✅ Strongly consider running GraphRAG in a Colab GPU runtime")
    if emb_speedup:
        # time_saved comes from the simulation above, which needs emb_speedup
        print(f" ✅ Expect to save {time_saved:.0f}+ minutes of index build time")
    print("\n📚 Next steps:")
    print(" 1. Upload the adaptive_RAG project files to Colab")
    print(" 2. Run main_graphrag.py to build the full knowledge graph")
    print(" 3. Download the results for local use")
else:
    print(" ⚠️ Enable the GPU for best performance")
    print(" ⚠️ Path: Runtime → Change runtime type → GPU")

print("\n" + "="*70)
print("✅ Test complete! Thanks for using the GraphRAG GPU test tool")
print("="*70)

# ============================================================
# 7️⃣ Optional: show nvidia-smi
# ============================================================
if cuda_available:
    print("\n📊 nvidia-smi details:")
    print("="*70)
    import subprocess
    try:
        result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
        print(result.stdout)
    except OSError:
        print("⚠️ Could not run nvidia-smi")
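Both embedding benchmarks above fix `batch_size=32`. On a GPU, throughput usually keeps improving with larger batches until memory pressure sets in, so the measured speedup is partly a function of that choice. A quick sweep can find the sweet spot (a sketch; `sweep_batch_sizes` is an illustrative helper, not part of the project):

```python
import time
from sentence_transformers import SentenceTransformer

def sweep_batch_sizes(texts, device="cuda", sizes=(16, 32, 64, 128)):
    """Print embedding throughput (texts/sec) for each candidate batch size."""
    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=device)
    for bs in sizes:
        start = time.perf_counter()
        model.encode(texts, batch_size=bs, show_progress_bar=False)
        elapsed = time.perf_counter() - start
        print(f"batch_size={bs}: {len(texts) / elapsed:.1f} texts/sec")
```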
colab_setup_and_run.py
ADDED
@@ -0,0 +1,375 @@
#!/usr/bin/env python3
"""
Full GraphRAG run script for the Google Colab environment
Handles starting the Ollama service and running GraphRAG

Usage:
1. Enable the GPU in Colab
2. Copy this file to Colab
3. Run: !python colab_setup_and_run.py
"""

import os
import sys
import time
import subprocess
import signal
from pathlib import Path

print("="*70)
print("🚀 GraphRAG Colab automated deployment script")
print("="*70)

# ============================================================
# 1️⃣ Detect the Colab environment
# ============================================================
def check_colab_environment():
    """Check whether we are running in Colab"""
    try:
        import google.colab
        print("\n✅ Environment: Google Colab")
        return True
    except ImportError:
        print("\n⚠️ Warning: Colab environment not detected")
        print(" This script is tuned for Colab and may need changes elsewhere")
        return False

# ============================================================
# 2️⃣ Install Ollama
# ============================================================
def install_ollama():
    """Install Ollama in Colab"""
    print("\n" + "="*70)
    print("📦 Step 1: install Ollama")
    print("="*70)

    # Check whether it is already installed
    if os.path.exists("/usr/local/bin/ollama"):
        print("✅ Ollama already installed")
        return True

    print("\n📥 Downloading and installing Ollama...")
    try:
        # Download the Ollama install script
        subprocess.run(
            ["curl", "-fsSL", "https://ollama.com/install.sh", "-o", "/tmp/install_ollama.sh"],
            check=True,
            capture_output=True
        )

        # Run the installer
        subprocess.run(
            ["sh", "/tmp/install_ollama.sh"],
            check=True,
            capture_output=True
        )

        print("✅ Ollama installed successfully")
        return True

    except subprocess.CalledProcessError as e:
        print(f"❌ Ollama installation failed: {e}")
        return False

# ============================================================
# 3️⃣ Start the Ollama service in the background
# ============================================================
def start_ollama_service():
    """Start the Ollama service in the background"""
    print("\n" + "="*70)
    print("🔧 Step 2: start the Ollama service")
    print("="*70)

    print("\n🔄 Starting the Ollama service in the background...")

    # Approach 1: background run via subprocess
    try:
        # Start the Ollama server (background)
        ollama_process = subprocess.Popen(
            ["ollama", "serve"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            preexec_fn=os.setpgrp  # put it in a new process group
        )

        # Wait for the service to come up
        print("⏳ Waiting for the Ollama service to start...")
        time.sleep(5)

        # Check whether the service is running
        try:
            result = subprocess.run(
                ["curl", "-s", "http://localhost:11434/api/tags"],
                capture_output=True,
                timeout=3
            )

            if result.returncode == 0:
                print("✅ Ollama service is up (PID: {})".format(ollama_process.pid))

                # Save the PID for later management
                with open("/tmp/ollama.pid", "w") as f:
                    f.write(str(ollama_process.pid))

                return ollama_process
            else:
                print("⚠️ The service may not have started cleanly, continuing anyway...")
                # The process itself was started, so hand it back rather than
                # falling through to an implicit None
                return ollama_process

        except subprocess.TimeoutExpired:
            print("⚠️ Service check timed out, but the process was started")
            return ollama_process

    except Exception as e:
        print(f"❌ Failed to start Ollama: {e}")
        return None

# ============================================================
# 4️⃣ Pull the Mistral model
# ============================================================
def pull_mistral_model():
    """Pull the Mistral model"""
    print("\n" + "="*70)
    print("📥 Step 3: pull the Mistral model")
    print("="*70)

    print("\n🔄 Pulling the mistral model (this can take a few minutes)...")

    try:
        # Check whether the model already exists
        result = subprocess.run(
            ["ollama", "list"],
            capture_output=True,
            text=True,
            timeout=10
        )

        if "mistral" in result.stdout:
            print("✅ Mistral model already present")
            return True

        # Pull the model
        print("📥 Starting the Mistral model download...")
        process = subprocess.Popen(
            ["ollama", "pull", "mistral"],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True
        )

        # Stream the download progress
        for line in process.stdout:
            print(f" {line.strip()}")

        process.wait()

        if process.returncode == 0:
            print("✅ Mistral model downloaded")
            return True
        else:
            print("❌ Model download failed")
            return False

    except Exception as e:
        print(f"❌ Failed to pull the Mistral model: {e}")
        return False

# ============================================================
# 5️⃣ Install the Python dependencies
# ============================================================
def install_python_dependencies():
    """Install the Python packages GraphRAG needs"""
    print("\n" + "="*70)
    print("📦 Step 4: install Python dependencies")
    print("="*70)

    packages = [
        "langchain",
        "langchain-community",
        "langchain-core",
        "langgraph",
        "langchain-ollama",
        "chromadb",
        "sentence-transformers",
        "tiktoken",
        "beautifulsoup4",
        "requests",
        "tavily-python",
        "python-dotenv",
        "networkx",
        "python-louvain",
        "torch",
        "transformers"
    ]

    print("\n📥 Installing the required Python packages...")
    for package in packages:
        try:
            # Note: dash-to-underscore is only an approximation of the import
            # name (e.g. python-louvain actually imports as 'community'), so a
            # few packages are re-installed each run; pip treats that as a no-op.
            __import__(package.replace("-", "_"))
            print(f"✅ {package} already installed")
        except ImportError:
            print(f"📥 Installing {package}...")
            subprocess.run(
                [sys.executable, "-m", "pip", "install", "-q", package],
                check=True
            )

    print("\n✅ All dependencies installed")

# ============================================================
# 6️⃣ Configure environment variables
# ============================================================
def setup_environment():
    """Configure environment variables"""
    print("\n" + "="*70)
    print("🔑 Step 5: configure environment variables")
    print("="*70)

    # Check for a .env file
    if os.path.exists(".env"):
        print("\n✅ Found .env file, loading configuration...")
        from dotenv import load_dotenv
        load_dotenv()
    else:
        print("\n⚠️ No .env file found")

        # Prompt for the API key interactively
        if "TAVILY_API_KEY" not in os.environ:
            from getpass import getpass
            api_key = getpass("Enter TAVILY_API_KEY (or press Enter to skip): ")
            if api_key:
                os.environ["TAVILY_API_KEY"] = api_key
                print("✅ TAVILY_API_KEY set")
            else:
                print("⚠️ Skipping TAVILY_API_KEY (web search will be unavailable)")

    print("\n📋 Current environment variables:")
    print(f" TAVILY_API_KEY: {'set' if os.environ.get('TAVILY_API_KEY') else 'not set'}")

# ============================================================
# 7️⃣ Run GraphRAG
# ============================================================
def run_graphrag():
    """Run the GraphRAG main program"""
    print("\n" + "="*70)
    print("🚀 Step 6: run GraphRAG")
    print("="*70)

    # Check that main_graphrag.py exists
    if not os.path.exists("main_graphrag.py"):
        print("\n❌ main_graphrag.py not found")
        print(" Make sure the project files have been uploaded to Colab")
        return False

    print("\n🔄 Starting the GraphRAG index build...\n")

    try:
        # Run GraphRAG
        result = subprocess.run(
            [sys.executable, "main_graphrag.py"],
            capture_output=False,  # stream the output live
            text=True
        )

        if result.returncode == 0:
            print("\n✅ GraphRAG run succeeded!")
            return True
        else:
            print(f"\n❌ GraphRAG run failed (exit code: {result.returncode})")
            return False

    except KeyboardInterrupt:
        print("\n⚠️ Execution interrupted by user")
        return False
    except Exception as e:
        print(f"\n❌ Error while running GraphRAG: {e}")
        return False

# ============================================================
# 8️⃣ Cleanup
# ============================================================
def cleanup():
    """Clean up background processes"""
    print("\n" + "="*70)
    print("🧹 Cleaning up background processes")
    print("="*70)

    # Stop the Ollama service
    if os.path.exists("/tmp/ollama.pid"):
        try:
            with open("/tmp/ollama.pid", "r") as f:
                pid = int(f.read().strip())

            os.kill(pid, signal.SIGTERM)
            print(f"✅ Ollama service stopped (PID: {pid})")
            os.remove("/tmp/ollama.pid")

        except Exception as e:
            print(f"⚠️ Failed to stop the Ollama service: {e}")

# ============================================================
# Main entry point
# ============================================================
def main():
    """Main execution flow"""
    ollama_process = None

    try:
        # 1. Detect the environment
        is_colab = check_colab_environment()

        # 2. Install Ollama
        if not install_ollama():
            print("\n❌ Ollama installation failed, cannot continue")
            return

        # 3. Start the Ollama service
        ollama_process = start_ollama_service()
        if not ollama_process:
            print("\n❌ Ollama service failed to start, cannot continue")
            return

        # 4. Pull the model
        if not pull_mistral_model():
            print("\n❌ Mistral model download failed, cannot continue")
            return

        # 5. Install the Python dependencies
        install_python_dependencies()

        # 6. Configure the environment
        setup_environment()

        # 7. Run GraphRAG
        success = run_graphrag()

        if success:
            print("\n" + "="*70)
            print("✅ All tasks complete!")
            print("="*70)

            print("\n📊 Generated files:")
            if os.path.exists("data/knowledge_graph.json"):
                print(" ✅ data/knowledge_graph.json")

            # Offer a download option
            if is_colab:
                print("\n💾 Download the results:")
                print(" from google.colab import files")
                print(" files.download('data/knowledge_graph.json')")

    except KeyboardInterrupt:
        print("\n\n⚠️ Execution interrupted by user")

    except Exception as e:
        print(f"\n❌ Error during execution: {e}")
        import traceback
        traceback.print_exc()

    finally:
        # Cleanup
        print("\n⚠️ Note: the Ollama service is still running in the background")
        print(" To stop it: !pkill -f 'ollama serve'")
        print(" Or call: cleanup()")

if __name__ == "__main__":
    main()
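`start_ollama_service` above waits a fixed five seconds before probing the API once, which can race a slow cold start. A polling variant along these lines is more robust (a sketch; `wait_for_ollama` is illustrative and uses the `requests` package the script already installs):

```python
import time
import requests

def wait_for_ollama(url="http://localhost:11434/api/tags", timeout=60):
    """Poll the Ollama HTTP API until it responds, or give up after `timeout` seconds."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            if requests.get(url, timeout=2).status_code == 200:
                return True
        except requests.RequestException:
            pass  # server not up yet; retry
        time.sleep(1)
    return False
```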
config.py
ADDED
@@ -0,0 +1,87 @@
"""
Configuration and environment setup module
Covers API key management, model configuration, and URL configuration
"""

import getpass
import os

# Try to load the .env file; skip if dotenv is not installed
try:
    from dotenv import load_dotenv
    load_dotenv()
    print("✅ .env file loaded")
except ImportError:
    print("⚠️ python-dotenv is not installed, falling back to system environment variables")
    print("Tip: run 'pip install python-dotenv' to install it")


def _set_env(var: str):
    """Set an environment variable; prefer the .env file, otherwise prompt the user"""
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"{var}: ")


def setup_environment():
    """Set all required environment variables"""
    _set_env("TAVILY_API_KEY")
    # NOMIC_API_KEY is no longer needed; we use local HuggingFace embeddings

    # Verify that the API key is set
    tavily_key = os.environ.get("TAVILY_API_KEY")

    if tavily_key:
        print("✅ TAVILY_API_KEY loaded from environment variables")
    else:
        print("⚠️ TAVILY_API_KEY not found")


# Model configuration
LOCAL_LLM = "mistral"

# Knowledge base URLs
KNOWLEDGE_BASE_URLS = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

# Document chunking configuration
CHUNK_SIZE = 250
CHUNK_OVERLAP = 0

# Vector store configuration
COLLECTION_NAME = "rag-chroma"
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # HuggingFace embedding model

# Search configuration
WEB_SEARCH_RESULTS_COUNT = 3

# GraphRAG configuration
ENABLE_GRAPHRAG = True  # whether to enable GraphRAG
GRAPHRAG_INDEX_PATH = "./data/knowledge_graph.json"  # where to save the graph index
GRAPHRAG_COMMUNITY_ALGORITHM = "louvain"  # community detection: louvain, greedy, label_propagation
GRAPHRAG_MAX_HOPS = 2  # max hops for local queries
GRAPHRAG_TOP_K_COMMUNITIES = 5  # number of communities used for global queries
GRAPHRAG_BATCH_SIZE = 10  # batch size for entity extraction


def get_api_keys():
    """Return the API keys as a dictionary"""
    return {
        "tavily": os.environ.get("TAVILY_API_KEY")
    }


def validate_api_keys():
    """Validate that the API keys are set"""
    keys = get_api_keys()
    missing_keys = []

    if not keys["tavily"]:
        missing_keys.append("TAVILY_API_KEY")

    if missing_keys:
        raise ValueError(f"Missing required API keys: {', '.join(missing_keys)}\nPlease set them in the .env file")

    return True
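For context, other modules in the project (e.g. `graph_indexer.py`, `graph_retriever.py`) import these constants directly. A minimal usage sketch (illustrative only; the exact call sites live in those modules):

```python
import config

config.setup_environment()
config.validate_api_keys()  # raises ValueError if TAVILY_API_KEY is missing

if config.ENABLE_GRAPHRAG:
    print(f"Graph index path: {config.GRAPHRAG_INDEX_PATH}")
    print(f"Community detection: {config.GRAPHRAG_COMMUNITY_ALGORITHM}")
    print(f"Local query max hops: {config.GRAPHRAG_MAX_HOPS}")
```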
deploy_gpu.sh
ADDED
@@ -0,0 +1,240 @@
#!/bin/bash
# GPU deployment script - one-command deployment of the adaptive RAG system to a Linux RTX 4090 environment

set -e  # Exit immediately on error

echo "🚀 Starting deployment of the adaptive RAG system to the GPU environment..."

# Color definitions
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color

# Make sure we are not running as root
if [[ $EUID -eq 0 ]]; then
    echo -e "${RED}Please do not run this script as root${NC}"
    exit 1
fi

# Check the GPU
check_gpu() {
    echo "🔍 Checking the GPU environment..."
    if command -v nvidia-smi &> /dev/null; then
        echo -e "${GREEN}✅ NVIDIA GPU found:${NC}"
        nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
    else
        echo -e "${RED}❌ No NVIDIA GPU found, or the driver is not installed${NC}"
        exit 1
    fi
}

# Check CUDA
check_cuda() {
    echo "🔍 Checking the CUDA environment..."
    if command -v nvcc &> /dev/null; then
        echo -e "${GREEN}✅ CUDA version:${NC}"
        nvcc --version | grep "release"
    else
        echo -e "${YELLOW}⚠️ CUDA is not installed or not on PATH${NC}"
    fi
}

# Install Docker
install_docker() {
    if ! command -v docker &> /dev/null; then
        echo "📦 Installing Docker..."
        curl -fsSL https://get.docker.com -o get-docker.sh
        sudo sh get-docker.sh
        sudo usermod -aG docker $USER
        rm get-docker.sh
        echo -e "${GREEN}✅ Docker installed${NC}"
    else
        echo -e "${GREEN}✅ Docker is already installed${NC}"
    fi
}

# Install the NVIDIA Container Toolkit
install_nvidia_docker() {
    if ! docker run --rm --gpus all nvidia/cuda:11.0.3-base-ubuntu20.04 nvidia-smi &> /dev/null; then
        echo "🐳 Installing the NVIDIA Container Toolkit..."
        distribution=$(. /etc/os-release;echo $ID$VERSION_ID)
        curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg
        curl -s -L https://nvidia.github.io/libnvidia-container/$distribution/libnvidia-container.list | \
            sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \
            sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list

        sudo apt-get update
        sudo apt-get install -y nvidia-container-toolkit
        sudo systemctl restart docker
        echo -e "${GREEN}✅ NVIDIA Container Toolkit installed${NC}"
    else
        echo -e "${GREEN}✅ NVIDIA Container Toolkit is already configured${NC}"
    fi
}

# Create the environment configuration
setup_env() {
    echo "⚙️ Configuring environment variables..."
    if [ ! -f .env ]; then
        cp .env.example .env
        echo -e "${YELLOW}⚠️ Please edit the .env file and set your API keys${NC}"
        echo " - TAVILY_API_KEY: get one from https://tavily.com/"
        read -p "Press Enter to continue..."
    fi
}

# Choose the deployment method
choose_deployment() {
    echo "🎯 Choose a deployment method:"
    echo "1) Docker Compose deployment (recommended)"
    echo "2) Direct Python deployment"
    read -p "Select (1-2): " choice

    case $choice in
        1)
            deploy_docker
            ;;
        2)
            deploy_python
            ;;
        *)
            echo -e "${RED}Invalid selection${NC}"
            exit 1
            ;;
    esac
}

# Docker deployment
deploy_docker() {
    echo "🐳 Deploying with Docker Compose..."

    # Install docker-compose
    if ! command -v docker-compose &> /dev/null; then
        echo "Installing Docker Compose..."
        sudo curl -L "https://github.com/docker/compose/releases/download/v2.20.0/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose
        sudo chmod +x /usr/local/bin/docker-compose
    fi

    # Build and start the services
    echo "Building images..."
    docker-compose -f docker-compose.gpu.yml build

    echo "Starting services..."
    docker-compose -f docker-compose.gpu.yml up -d

    echo -e "${GREEN}✅ Docker deployment complete!${NC}"
    echo "Main service: http://localhost:8000"
    echo "Monitoring: http://localhost:9445 (GPU monitoring)"
    echo "Logs: docker-compose -f docker-compose.gpu.yml logs -f"
}

# Direct Python deployment
deploy_python() {
    echo "🐍 Deploying directly with Python..."

    # Check Python
    if ! command -v python3 &> /dev/null; then
        echo "Installing Python3..."
        sudo apt-get update
        sudo apt-get install -y python3 python3-pip python3-venv
    fi

    # Create a virtual environment
    if [ ! -d "rag_env" ]; then
        echo "Creating a Python virtual environment..."
        python3 -m venv rag_env
    fi

    # Activate the virtual environment and install dependencies
    source rag_env/bin/activate
    pip install --upgrade pip
    pip install -r requirements_gpu.txt

    # Install Ollama
    if ! command -v ollama &> /dev/null; then
        echo "Installing Ollama..."
        curl -fsSL https://ollama.ai/install.sh | sh
    fi

    # Start the Ollama service
    echo "Starting the Ollama service..."
    ollama serve &
    sleep 5

    # Pull the model
    echo "Pulling the Mistral model..."
    ollama pull mistral

    # Create a startup script
    cat > start_gpu.sh << 'EOF'
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0
export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512
export TOKENIZERS_PARALLELISM=false
source rag_env/bin/activate
python main.py
EOF
    chmod +x start_gpu.sh

    echo -e "${GREEN}✅ Python deployment complete!${NC}"
    echo "Start command: ./start_gpu.sh"
}

# Verify the deployment
verify_deployment() {
    echo "🔍 Verifying the deployment..."
    sleep 10

    if curl -f http://localhost:8000/health 2>/dev/null; then
        echo -e "${GREEN}✅ Service is running${NC}"
    else
        echo -e "${YELLOW}⚠️ The service may still be starting; please check again shortly${NC}"
    fi

    # Show GPU usage
    echo "📊 GPU status:"
    nvidia-smi --query-gpu=utilization.gpu,memory.used,memory.total --format=csv,noheader,nounits
}

# Show deployment information
show_info() {
    echo ""
    echo "🎉 Deployment complete!"
    echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
    echo "📊 Service endpoints:"
    echo " - Main service: http://localhost:8000"
    echo " - Ollama: http://localhost:11434"
    echo ""
    echo "🔧 Common commands:"
    echo " - View logs: docker-compose -f docker-compose.gpu.yml logs -f"
    echo " - Restart services: docker-compose -f docker-compose.gpu.yml restart"
    echo " - Stop services: docker-compose -f docker-compose.gpu.yml down"
    echo " - GPU monitoring: watch -n 1 nvidia-smi"
    echo ""
    echo "📚 Documentation:"
    echo " - Deployment guide: DEPLOYMENT_GUIDE.md"
    echo " - Quick start: QUICKSTART.md"
    echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
}

# Main function
main() {
    echo "🤖 Adaptive RAG system GPU deployment script"
    echo "Target: Linux + RTX 4090"
    echo ""

    check_gpu
    check_cuda
    install_docker
    install_nvidia_docker
    setup_env
    choose_deployment
    verify_deployment
    show_info
}

# Run the main function
main "$@"
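Usage note (deploy_gpu.sh): run chmod +x deploy_gpu.sh && ./deploy_gpu.sh from the repository root. The health probe in verify_deployment can also be reproduced from Python with retries instead of a fixed sleep; a minimal sketch, assuming the requests package and the /health endpoint on port 8000 that the script itself probes:

# Illustrative only: poll the service health endpoint with retries.
import time
import requests

def wait_for_service(url="http://localhost:8000/health", attempts=12, delay=5):
    for _ in range(attempts):
        try:
            if requests.get(url, timeout=3).ok:
                return True
        except requests.RequestException:
            pass  # service not up yet; retry after a short delay
        time.sleep(delay)
    return False

print("Service healthy:", wait_for_service())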
docker-compose.gpu.yml
ADDED
@@ -0,0 +1,70 @@
# Docker Compose configuration file - GPU deployment
version: '3.8'

services:
  adaptive-rag:
    build:
      context: .
      dockerfile: Dockerfile.gpu
    container_name: adaptive-rag-gpu
    restart: unless-stopped
    environment:
      - CUDA_VISIBLE_DEVICES=0
      - PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512
      - TOKENIZERS_PARALLELISM=false
      - HF_HOME=/app/models
      - TRANSFORMERS_CACHE=/app/models
    env_file:
      - .env
    ports:
      - "8000:8000"
      - "8001:8001"  # Optional: monitoring port
    volumes:
      - ./data:/app/data
      - ./models:/app/models
      - ./logs:/app/logs
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
    depends_on:
      - ollama

  ollama:
    image: ollama/ollama:latest
    container_name: ollama-gpu
    restart: unless-stopped
    ports:
      - "11434:11434"
    volumes:
      - ollama-data:/root/.ollama
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
    # The image's entrypoint is already /bin/ollama, so only the subcommand is
    # passed here (the original ["ollama", "serve"] would repeat the binary name).
    command: ["serve"]

  # Optional: monitoring service
  nvidia-smi-exporter:
    image: mindprince/nvidia_gpu_prometheus_exporter:0.1
    container_name: gpu-monitor
    restart: unless-stopped
    ports:
      - "9445:9445"
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]

volumes:
  ollama-data:
    driver: local
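Usage note (docker-compose.gpu.yml): after docker-compose -f docker-compose.gpu.yml up -d, a quick way to confirm the Ollama container is serving on its mapped port is to list its models via Ollama's standard /api/tags endpoint; a minimal sketch, assuming the requests package is available:

# Illustrative only: check that the ollama container answers on port 11434.
import requests

resp = requests.get("http://localhost:11434/api/tags", timeout=5)
resp.raise_for_status()
models = [m["name"] for m in resp.json().get("models", [])]
print("Models available in the container:", models or "none pulled yet")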
document_processor.py
ADDED
@@ -0,0 +1,195 @@
"""
Document processing and vectorization module
Handles document loading, text chunking, embedding, and vector store initialization
"""

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings

from config import (
    KNOWLEDGE_BASE_URLS,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
    COLLECTION_NAME,
    EMBEDDING_MODEL
)
from reranker import create_reranker


class DocumentProcessor:
    """Document processor: loads, chunks, and vectorizes documents"""

    def __init__(self):
        self.text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP
        )

        # Try to initialize embeddings with error handling
        try:
            import torch
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            print(f"✅ Detected device: {device}")
            if device == 'cuda':
                print(f"   GPU model: {torch.cuda.get_device_name(0)}")
                print(f"   GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB")

            self.embeddings = HuggingFaceEmbeddings(
                model_name="sentence-transformers/all-MiniLM-L6-v2",  # lightweight embedding model
                model_kwargs={'device': device},  # pick GPU or CPU automatically
                encode_kwargs={'normalize_embeddings': True}  # normalize the embedding vectors
            )
            print(f"✅ HuggingFace embedding model initialized (device: {device})")
        except Exception as e:
            print(f"⚠️ HuggingFace embedding initialization failed: {e}")
            print("Trying a fallback embedding scheme...")
            # Fall back to a dummy embedding model so the rest of the pipeline can run
            from langchain_community.embeddings import FakeEmbeddings
            self.embeddings = FakeEmbeddings(size=384)  # for testing purposes only
            print("✅ Using the test embedding model")

        self.vectorstore = None
        self.retriever = None

        # Initialize the reranker
        self.reranker = None
        self._setup_reranker()

    def _setup_reranker(self):
        """Set up the reranker"""
        try:
            # Use the hybrid reranker for the best results
            self.reranker = create_reranker('hybrid', self.embeddings)
            print("✅ Reranker initialized")
        except Exception as e:
            print(f"⚠️ Reranker initialization failed: {e}")
            print("Falling back to basic retrieval without reranking")

    def load_documents(self, urls=None):
        """Load documents from URLs"""
        if urls is None:
            urls = KNOWLEDGE_BASE_URLS

        print(f"Loading documents from {len(urls)} URLs...")
        docs = [WebBaseLoader(url).load() for url in urls]
        docs_list = [item for sublist in docs for item in sublist]
        print(f"Loaded {len(docs_list)} documents")
        return docs_list

    def split_documents(self, docs):
        """Split documents into chunks"""
        print("Splitting documents...")
        doc_splits = self.text_splitter.split_documents(docs)
        print(f"Splitting complete: {len(doc_splits)} chunks")
        return doc_splits

    def create_vectorstore(self, doc_splits):
        """Create the vector store"""
        print("Creating the vector store...")
        self.vectorstore = Chroma.from_documents(
            documents=doc_splits,
            collection_name=COLLECTION_NAME,
            embedding=self.embeddings,
        )
        self.retriever = self.vectorstore.as_retriever()
        print("Vector store created")
        return self.vectorstore, self.retriever

    def setup_knowledge_base(self, urls=None, enable_graphrag=False):
        """Set up the complete knowledge base (load, split, vectorize)

        Args:
            urls: list of document URLs
            enable_graphrag: whether to enable GraphRAG indexing

        Returns:
            vectorstore, retriever, doc_splits
        """
        docs = self.load_documents(urls)
        doc_splits = self.split_documents(docs)
        vectorstore, retriever = self.create_vectorstore(doc_splits)

        # Return doc_splits so they can be fed into the GraphRAG indexer
        return vectorstore, retriever, doc_splits

    def enhanced_retrieve(self, query: str, top_k: int = 5, rerank_candidates: int = 20):
        """Enhanced retrieval: fetch a larger candidate pool first, then rerank"""
        if not self.retriever:
            print("⚠️ Retriever is not initialized")
            return []

        # 1. Initial retrieval
        initial_docs = self.retriever.get_relevant_documents(query)

        # Fetch a wider candidate pool if the retriever allows it
        if hasattr(self.retriever, 'search_kwargs'):
            # Temporarily raise k to the requested candidate count (the original
            # code capped this at len(initial_docs), which never widened the
            # pool beyond the default k)
            original_k = self.retriever.search_kwargs.get('k', 4)
            self.retriever.search_kwargs['k'] = rerank_candidates
            candidate_docs = self.retriever.get_relevant_documents(query)
            self.retriever.search_kwargs['k'] = original_k  # restore the original setting
        else:
            candidate_docs = initial_docs

        print(f"Initial retrieval produced {len(candidate_docs)} candidate documents")

        # 2. Rerank (if a reranker is available)
        if self.reranker and len(candidate_docs) > top_k:
            try:
                reranked_results = self.reranker.rerank(query, candidate_docs, top_k)
                final_docs = [doc for doc, score in reranked_results]
                scores = [score for doc, score in reranked_results]

                print(f"Reranking returned {len(final_docs)} documents")
                print(f"Rerank score range: {min(scores):.4f} - {max(scores):.4f}")

                return final_docs
            except Exception as e:
                print(f"⚠️ Reranking failed: {e}; using the raw retrieval results")
                return candidate_docs[:top_k]
        else:
            # No reranker, or not enough candidates to rerank
            return candidate_docs[:top_k]

    def compare_retrieval_methods(self, query: str, top_k: int = 5):
        """Compare the results of the different retrieval methods"""
        if not self.retriever:
            return {}

        # Plain retrieval
        original_docs = self.retriever.get_relevant_documents(query)[:top_k]

        # Enhanced retrieval (with reranking)
        enhanced_docs = self.enhanced_retrieve(query, top_k)

        return {
            'query': query,
            'original_retrieval': {
                'count': len(original_docs),
                'documents': [{
                    'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
                    'metadata': getattr(doc, 'metadata', {})
                } for doc in original_docs]
            },
            'enhanced_retrieval': {
                'count': len(enhanced_docs),
                'documents': [{
                    'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
                    'metadata': getattr(doc, 'metadata', {})
                } for doc in enhanced_docs]
            },
            'reranker_used': self.reranker is not None
        }

    def format_docs(self, docs):
        """Format documents for generation"""
        return "\n\n".join(doc.page_content for doc in docs)


def initialize_document_processor():
    """Initialize the document processor and set up the knowledge base"""
    processor = DocumentProcessor()
    vectorstore, retriever, doc_splits = processor.setup_knowledge_base()
    return processor, vectorstore, retriever, doc_splits
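Usage note (document_processor.py): a minimal sketch of the retrieve-then-rerank flow this module implements. It fetches the knowledge-base URLs over the network and needs the reranker module, so treat it as illustrative; the question text is made up:

# Illustrative only: build the knowledge base, then run an enhanced retrieval.
from document_processor import initialize_document_processor

processor, vectorstore, retriever, doc_splits = initialize_document_processor()

question = "What are the main components of an LLM-powered agent?"
docs = processor.enhanced_retrieve(question, top_k=3, rerank_candidates=10)
print(processor.format_docs(docs)[:500])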
entity_extractor.py
ADDED
@@ -0,0 +1,229 @@
"""
Entity and relation extraction module
Uses an LLM to extract entities, relations, and attributes from documents as the
foundation of the knowledge graph
"""

from typing import List, Dict
from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import JsonOutputParser
from config import LOCAL_LLM


class EntityExtractor:
    """Entity extractor - uses an LLM to pull entities out of text"""

    def __init__(self):
        self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)

        # Entity extraction prompt template
        self.entity_prompt = PromptTemplate(
            template="""You are an expert in named entity recognition. Extract all important entities from the following text.

Entity types include:
- PERSON: people, authors, researchers
- ORGANIZATION: organizations, institutions, companies
- CONCEPT: technical concepts, algorithms, methodologies
- TECHNOLOGY: concrete technologies, tools, frameworks
- PAPER: papers, publications
- EVENT: events, conferences

Text:
{text}

Return JSON with the following fields:
{{
    "entities": [
        {{
            "name": "entity name",
            "type": "entity type",
            "description": "short description"
        }}
    ]
}}

Do not include any preamble or explanation; return only the JSON.
""",
            input_variables=["text"]
        )

        # Relation extraction prompt template
        self.relation_prompt = PromptTemplate(
            template="""You are a relation extraction expert. Identify the relations between entities in the text.

Known entities:
{entities}

Text:
{text}

Identify the relations between the entities and return JSON:
{{
    "relations": [
        {{
            "source": "source entity name",
            "target": "target entity name",
            "relation_type": "relation type",
            "description": "relation description"
        }}
    ]
}}

Relation types include: AUTHOR_OF, USES, BASED_ON, RELATED_TO, PART_OF, APPLIES_TO, IMPROVES, CITES

Do not include any preamble or explanation; return only the JSON.
""",
            input_variables=["text", "entities"]
        )

        self.entity_chain = self.entity_prompt | self.llm | JsonOutputParser()
        self.relation_chain = self.relation_prompt | self.llm | JsonOutputParser()

    def extract_entities(self, text: str) -> List[Dict]:
        """
        Extract entities from text

        Args:
            text: input text

        Returns:
            list of entities
        """
        try:
            result = self.entity_chain.invoke({"text": text[:2000]})  # cap the input length
            entities = result.get("entities", [])
            print(f"✅ Extracted {len(entities)} entities")
            return entities
        except Exception as e:
            print(f"❌ Entity extraction failed: {e}")
            return []

    def extract_relations(self, text: str, entities: List[Dict]) -> List[Dict]:
        """
        Extract entity relations from text

        Args:
            text: input text
            entities: list of previously identified entities

        Returns:
            list of relations
        """
        try:
            entity_names = [e["name"] for e in entities]
            result = self.relation_chain.invoke({
                "text": text[:2000],
                "entities": ", ".join(entity_names)
            })
            relations = result.get("relations", [])
            print(f"✅ Extracted {len(relations)} relations")
            return relations
        except Exception as e:
            print(f"❌ Relation extraction failed: {e}")
            return []

    def extract_from_document(self, document_text: str) -> Dict:
        """
        Extract entities and relations from a single document

        Args:
            document_text: document text

        Returns:
            dictionary with entities and relations
        """
        print("🔍 Extracting entities...")
        entities = self.extract_entities(document_text)

        print("🔍 Extracting relations...")
        relations = self.extract_relations(document_text, entities)

        return {
            "entities": entities,
            "relations": relations
        }


class EntityDeduplicator:
    """Entity deduplication and merging"""

    def __init__(self):
        self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)

        self.merge_prompt = PromptTemplate(
            template="""Decide whether the following two entities refer to the same object:

Entity 1: {entity1_name} - {entity1_desc}
Entity 2: {entity2_name} - {entity2_desc}

If they are the same object, return:
{{
    "is_same": true,
    "canonical_name": "canonical name",
    "reason": "reason"
}}

If not, return:
{{
    "is_same": false,
    "reason": "reason"
}}

Return only the JSON, nothing else.
""",
            input_variables=["entity1_name", "entity1_desc", "entity2_name", "entity2_desc"]
        )

        self.merge_chain = self.merge_prompt | self.llm | JsonOutputParser()

    def deduplicate_entities(self, entities: List[Dict]) -> Dict:
        """
        Deduplicate a list of entities

        Args:
            entities: list of entities

        Returns:
            dictionary with "entities" and "mapping"
        """
        if len(entities) <= 1:
            # Return the same dictionary shape for consistency
            entity_mapping = {entity["name"]: entity["name"] for entity in entities} if entities else {}
            return {
                "entities": entities,
                "mapping": entity_mapping
            }

        print(f"🔄 Deduplicating {len(entities)} entities...")

        # Simple name-based deduplication
        unique_entities = {}
        entity_mapping = {}  # maps aliases to canonical names

        for entity in entities:
            name = entity["name"].lower().strip()

            # Look for a similar existing entity
            merged = False
            for canonical_name in unique_entities:
                # Simple substring match (the LLM merge_chain above could make this smarter)
                if name in canonical_name or canonical_name in name:
                    entity_mapping[entity["name"]] = canonical_name
                    merged = True
                    break

            if not merged:
                unique_entities[name] = entity
                entity_mapping[entity["name"]] = name

        print(f"✅ Deduplication complete; {len(unique_entities)} unique entities remain")

        return {
            "entities": list(unique_entities.values()),
            "mapping": entity_mapping
        }


def initialize_entity_extractor():
    """Initialize the entity extractor"""
    return EntityExtractor()
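Usage note (entity_extractor.py): a minimal sketch of a single extraction pass. It assumes a running Ollama server with the mistral model already pulled, and the sample sentence is made up for illustration:

# Illustrative only: extract entities and relations from a short passage.
from entity_extractor import EntityExtractor

extractor = EntityExtractor()
text = ("Chain-of-thought prompting, introduced by researchers at Google, "
        "improves reasoning in large language models.")
result = extractor.extract_from_document(text)
for entity in result["entities"]:
    print(entity.get("name"), "->", entity.get("type"))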
graph_indexer.py
ADDED
@@ -0,0 +1,145 @@
"""
GraphRAG indexer
Builds the hierarchical knowledge-graph index: entity extraction, graph
construction, community detection, and summary generation
"""

from typing import List
from langchain.schema import Document

from entity_extractor import EntityExtractor, EntityDeduplicator
from knowledge_graph import KnowledgeGraph, CommunitySummarizer


class GraphRAGIndexer:
    """GraphRAG indexer - implements the Microsoft GraphRAG indexing flow"""

    def __init__(self):
        print("🚀 Initializing the GraphRAG indexer...")

        self.entity_extractor = EntityExtractor()
        self.entity_deduplicator = EntityDeduplicator()
        self.knowledge_graph = KnowledgeGraph()
        self.community_summarizer = CommunitySummarizer()

        self.indexed = False

        print("✅ GraphRAG indexer initialized")

    def index_documents(self, documents: List[Document],
                        batch_size: int = 10,
                        save_path: str = None) -> KnowledgeGraph:
        """
        Build the GraphRAG index over a document collection

        Workflow (following Microsoft GraphRAG):
        1. Document chunking (already done in document_processor)
        2. Entity and relation extraction
        3. Entity deduplication and merging
        4. Knowledge graph construction
        5. Community detection
        6. Community summary generation

        Args:
            documents: list of documents
            batch_size: batch size
            save_path: where to save the index

        Returns:
            the constructed knowledge graph
        """
        print(f"\n{'='*50}")
        print(f"📊 Starting the GraphRAG indexing flow")
        print(f"   Number of documents: {len(documents)}")
        print(f"{'='*50}\n")

        # Step 1: entity and relation extraction
        print("📍 Step 1/5: entity and relation extraction")
        extraction_results = []

        for i in range(0, len(documents), batch_size):
            batch = documents[i:i+batch_size]
            print(f"   Processing batch {i//batch_size + 1}/{(len(documents)-1)//batch_size + 1}...")

            for doc in batch:
                result = self.entity_extractor.extract_from_document(doc.page_content)
                extraction_results.append(result)

        # Step 2: entity deduplication
        print("\n📍 Step 2/5: entity deduplication and merging")
        all_entities = []
        all_relations = []

        for result in extraction_results:
            all_entities.extend(result.get("entities", []))
            all_relations.extend(result.get("relations", []))

        dedup_result = self.entity_deduplicator.deduplicate_entities(all_entities)
        unique_entities = dedup_result["entities"]
        entity_mapping = dedup_result["mapping"]

        # Rewrite the entity names inside the relations using the dedup mapping
        mapped_relations = []
        for relation in all_relations:
            source = entity_mapping.get(relation["source"], relation["source"])
            target = entity_mapping.get(relation["target"], relation["target"])
            mapped_relations.append({
                **relation,
                "source": source,
                "target": target
            })

        # Step 3: build the knowledge graph
        print("\n📍 Step 3/5: building the knowledge graph")
        cleaned_results = [{
            "entities": unique_entities,
            "relations": mapped_relations
        }]
        self.knowledge_graph.build_from_extractions(cleaned_results)

        # Step 4: community detection
        print("\n📍 Step 4/5: community detection")
        self.knowledge_graph.detect_communities(algorithm="louvain")

        # Step 5: generate community summaries
        print("\n📍 Step 5/5: generating community summaries")
        self.community_summarizer.summarize_all_communities(self.knowledge_graph)

        # Save the graph
        if save_path:
            self.knowledge_graph.save_to_file(save_path)

        self.indexed = True

        # Print statistics
        print(f"\n{'='*50}")
        print("✅ GraphRAG index built!")
        stats = self.knowledge_graph.get_statistics()
        print(f"\n📊 Statistics:")
        print(f"   - Nodes: {stats['num_nodes']}")
        print(f"   - Edges: {stats['num_edges']}")
        print(f"   - Communities: {stats['num_communities']}")
        print(f"   - Graph density: {stats['density']:.4f}")
        print(f"\n   Entity type distribution:")
        for etype, count in stats['entity_types'].items():
            print(f"   • {etype}: {count}")
        print(f"{'='*50}\n")

        return self.knowledge_graph

    def get_graph(self) -> KnowledgeGraph:
        """Return the knowledge graph"""
        if not self.indexed:
            print("⚠️ The graph has not been built yet; call index_documents() first")
        return self.knowledge_graph

    def load_index(self, filepath: str) -> KnowledgeGraph:
        """Load an existing graph index"""
        print(f"📂 Loading the graph index from file: {filepath}")
        self.knowledge_graph.load_from_file(filepath)
        self.indexed = True
        return self.knowledge_graph


def initialize_graph_indexer():
    """Initialize the GraphRAG indexer"""
    return GraphRAGIndexer()
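Usage note (graph_indexer.py): a minimal end-to-end indexing sketch wiring this module to the document processor and the defaults from config.py. Indexing is slow on a local LLM, so only a subset of chunks is used here; the ./data directory behind GRAPHRAG_INDEX_PATH is assumed to exist:

# Illustrative only: chunk the knowledge base, then build and save the graph.
from document_processor import initialize_document_processor
from graph_indexer import initialize_graph_indexer
from config import GRAPHRAG_INDEX_PATH, GRAPHRAG_BATCH_SIZE

_, _, _, doc_splits = initialize_document_processor()
indexer = initialize_graph_indexer()
kg = indexer.index_documents(
    doc_splits[:20],                 # index a small subset first
    batch_size=GRAPHRAG_BATCH_SIZE,
    save_path=GRAPHRAG_INDEX_PATH,
)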
graph_retriever.py
ADDED
@@ -0,0 +1,275 @@
"""
GraphRAG retriever
Implements knowledge-graph-based retrieval strategies: local queries and global queries
"""

from typing import List, Dict
from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser

from knowledge_graph import KnowledgeGraph
from config import LOCAL_LLM


class GraphRetriever:
    """Knowledge-graph-based retriever"""

    def __init__(self, knowledge_graph: KnowledgeGraph):
        self.kg = knowledge_graph
        self.llm = ChatOllama(model=LOCAL_LLM, temperature=0.3)

        # Entity recognition prompt
        self.entity_recognition_prompt = PromptTemplate(
            template="""Identify the key entities and concepts in the following question:

Question: {question}

Examples of known entities: {sample_entities}

Identify the entities mentioned in the question and return JSON:
{{
    "entities": ["entity1", "entity2", ...]
}}

Return only the JSON, nothing else.
""",
            input_variables=["question", "sample_entities"]
        )

        # Global query prompt
        self.global_query_prompt = PromptTemplate(
            template="""You are a knowledge-graph analysis expert. Answer the user's question based on the community summaries below.

User question: {question}

Relevant community summaries:
{community_summaries}

Provide a comprehensive answer based on these summaries. If the summaries contain no relevant information, say so.

Answer:
""",
            input_variables=["question", "community_summaries"]
        )

        # Local query prompt
        self.local_query_prompt = PromptTemplate(
            template="""Answer the user's question based on the entity and relation information below.

User question: {question}

Relevant entity information:
{entity_info}

Relations between the entities:
{relations}

Provide an answer based on this information.

Answer:
""",
            input_variables=["question", "entity_info", "relations"]
        )

        self.entity_recognition_chain = self.entity_recognition_prompt | self.llm | JsonOutputParser()
        self.global_query_chain = self.global_query_prompt | self.llm | StrOutputParser()
        self.local_query_chain = self.local_query_prompt | self.llm | StrOutputParser()

    def recognize_entities(self, question: str) -> List[str]:
        """
        Identify the entities in a question

        Args:
            question: the user question

        Returns:
            list of recognized entities
        """
        # Take a few sample entities from the graph
        sample_entities = list(self.kg.entities.keys())[:10]
        sample_text = ", ".join(sample_entities)

        try:
            result = self.entity_recognition_chain.invoke({
                "question": question,
                "sample_entities": sample_text
            })
            entities = result.get("entities", [])

            # Match the recognized names against graph entities
            matched_entities = []
            for entity in entities:
                # Exact match
                if entity in self.kg.entities:
                    matched_entities.append(entity)
                else:
                    # Fuzzy match
                    for kg_entity in self.kg.entities.keys():
                        if entity.lower() in kg_entity.lower() or kg_entity.lower() in entity.lower():
                            matched_entities.append(kg_entity)
                            break

            print(f"🔍 Recognized entities: {matched_entities}")
            return matched_entities

        except Exception as e:
            print(f"❌ Entity recognition failed: {e}")
            return []

    def local_query(self, question: str, max_hops: int = 2, top_k: int = 10) -> str:
        """
        Local query - retrieve based on the entities in the question and their neighborhood

        Use case: detailed questions about specific entities,
        e.g. "Who are the authors of AlphaCodium?"

        Args:
            question: the user question
            max_hops: maximum number of hops
            top_k: maximum number of entities to use

        Returns:
            answer text
        """
        print(f"\n🔎 Running a local query...")

        # 1. Identify the entities in the question
        mentioned_entities = self.recognize_entities(question)

        if not mentioned_entities:
            return "No matching entities were found in the knowledge graph."

        # 2. Collect the entities' neighborhoods
        relevant_entities = set()
        for entity in mentioned_entities:
            neighbors = self.kg.get_node_neighbors(entity, depth=max_hops)
            relevant_entities.update(neighbors)

        relevant_entities = list(relevant_entities)[:top_k]

        # 3. Gather entity information
        entity_info_list = []
        for entity in relevant_entities:
            info = self.kg.get_entity_info(entity)
            if info:
                entity_info_list.append(
                    f"- {info['name']} ({info.get('type', 'UNKNOWN')}): {info.get('description', 'no description')}"
                )

        # 4. Gather relation information
        relation_list = []
        for u, v, data in self.kg.graph.edges(data=True):
            if u in relevant_entities and v in relevant_entities:
                relation_list.append(
                    f"- {u} --[{data.get('relation_type', 'RELATED')}]--> {v}: {data.get('description', '')}"
                )

        entity_info_text = "\n".join(entity_info_list) if entity_info_list else "No relevant entity information"
        relations_text = "\n".join(relation_list[:20]) if relation_list else "No relevant relations"

        # 5. Generate the answer
        try:
            answer = self.local_query_chain.invoke({
                "question": question,
                "entity_info": entity_info_text,
                "relations": relations_text
            })
            print(f"✅ Local query complete")
            return answer.strip()
        except Exception as e:
            print(f"❌ Local query failed: {e}")
            return "The query failed; please try again."

    def global_query(self, question: str, top_k_communities: int = 5) -> str:
        """
        Global query - high-level querying over community summaries

        Use case: broad questions that need a holistic view,
        e.g. "What are the main topics of these documents?"

        Args:
            question: the user question
            top_k_communities: number of communities to use

        Returns:
            answer text
        """
        print(f"\n🌍 Running a global query...")

        if not self.kg.community_summaries:
            return "The knowledge graph has no community summaries yet; run the indexing flow first."

        # Collect community summaries
        community_summaries = []
        for cid, summary in list(self.kg.community_summaries.items())[:top_k_communities]:
            community_summaries.append(f"Community {cid}:\n{summary}\n")

        summaries_text = "\n".join(community_summaries)

        # Generate the answer
        try:
            answer = self.global_query_chain.invoke({
                "question": question,
                "community_summaries": summaries_text
            })
            print(f"✅ Global query complete")
            return answer.strip()
        except Exception as e:
            print(f"❌ Global query failed: {e}")
            return "The query failed; please try again."

    def hybrid_query(self, question: str) -> Dict[str, str]:
        """
        Hybrid query - run both the local and the global query and return both results

        Args:
            question: the user question

        Returns:
            dictionary with the local and global query results
        """
        print(f"\n🔀 Running a hybrid query...")

        local_answer = self.local_query(question)
        global_answer = self.global_query(question)

        return {
            "local": local_answer,
            "global": global_answer,
            "question": question
        }

    def smart_query(self, question: str) -> str:
        """
        Smart query - pick the query strategy automatically based on the question type

        Args:
            question: the user question

        Returns:
            answer text
        """
        # Classify the question
        question_lower = question.lower()

        # Questions naming concrete entities -> local query
        mentioned_entities = self.recognize_entities(question)
        if mentioned_entities:
            print("📍 Concrete entities detected; using a local query")
            return self.local_query(question)

        # Broad questions -> global query (the keywords cover both Chinese and English phrasings)
        global_keywords = ["主要", "总体", "概述", "整体", "主题", "讨论", "内容", "what", "overview", "main", "topics"]
        if any(keyword in question_lower for keyword in global_keywords):
            print("🌐 Broad question detected; using a global query")
            return self.global_query(question)

        # Default to a local query
        print("📍 Falling back to a local query")
        return self.local_query(question)


def initialize_graph_retriever(knowledge_graph: KnowledgeGraph):
    """Initialize the GraphRAG retriever"""
    return GraphRetriever(knowledge_graph)
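Usage note (graph_retriever.py): a minimal query sketch that loads a previously saved index and exercises the strategy selection in smart_query; the two questions are illustrative:

# Illustrative only: load a saved graph and run one local and one global query.
from graph_indexer import initialize_graph_indexer
from graph_retriever import initialize_graph_retriever
from config import GRAPHRAG_INDEX_PATH

indexer = initialize_graph_indexer()
kg = indexer.load_index(GRAPHRAG_INDEX_PATH)
retriever = initialize_graph_retriever(kg)

# Entity-centric question -> local query; broad question -> global query.
print(retriever.smart_query("Who proposed chain-of-thought prompting?"))
print(retriever.smart_query("What are the main topics in these documents?"))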
knowledge_graph.py
ADDED
@@ -0,0 +1,347 @@
| 1 |
+
"""
|
| 2 |
+
知识图谱模块
|
| 3 |
+
实现GraphRAG的核心功能:图谱构建、社区检测、层次化摘要
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import networkx as nx
|
| 7 |
+
from typing import List, Dict, Set, Tuple, Optional
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
import json
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
from community import community_louvain # python-louvain
|
| 13 |
+
LOUVAIN_AVAILABLE = True
|
| 14 |
+
except ImportError:
|
| 15 |
+
LOUVAIN_AVAILABLE = False
|
| 16 |
+
print("⚠️ python-louvain未安装,社区检测功能受限")
|
| 17 |
+
|
| 18 |
+
from langchain.prompts import PromptTemplate
|
| 19 |
+
from langchain_community.chat_models import ChatOllama
|
| 20 |
+
from langchain_core.output_parsers import StrOutputParser
|
| 21 |
+
from config import LOCAL_LLM
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class KnowledgeGraph:
|
| 25 |
+
"""知识图谱类 - 使用NetworkX构建和管理图谱"""
|
| 26 |
+
|
| 27 |
+
def __init__(self):
|
| 28 |
+
self.graph = nx.Graph() # 无向图
|
| 29 |
+
self.entities = {} # 实体详细信息
|
| 30 |
+
self.communities = {} # 社区划分结果
|
| 31 |
+
self.community_summaries = {} # 社区摘要
|
| 32 |
+
|
| 33 |
+
def add_entity(self, name: str, entity_type: str, description: str = "", **kwargs):
|
| 34 |
+
"""添加实体节点"""
|
| 35 |
+
self.graph.add_node(
|
| 36 |
+
name,
|
| 37 |
+
type=entity_type,
|
| 38 |
+
description=description,
|
| 39 |
+
**kwargs
|
| 40 |
+
)
|
| 41 |
+
self.entities[name] = {
|
| 42 |
+
"name": name,
|
| 43 |
+
"type": entity_type,
|
| 44 |
+
"description": description,
|
| 45 |
+
**kwargs
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
def add_relation(self, source: str, target: str, relation_type: str,
|
| 49 |
+
description: str = "", weight: float = 1.0):
|
| 50 |
+
"""添加关系边"""
|
| 51 |
+
self.graph.add_edge(
|
| 52 |
+
source,
|
| 53 |
+
target,
|
| 54 |
+
relation_type=relation_type,
|
| 55 |
+
description=description,
|
| 56 |
+
weight=weight
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
def build_from_extractions(self, extraction_results: List[Dict]):
|
| 60 |
+
"""
|
| 61 |
+
从实体提取结果构建图谱
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
extraction_results: 实体和关系提取结果列表
|
| 65 |
+
"""
|
| 66 |
+
print("🔨 开始构建知识图谱...")
|
| 67 |
+
|
| 68 |
+
total_entities = 0
|
| 69 |
+
total_relations = 0
|
| 70 |
+
|
| 71 |
+
for result in extraction_results:
|
| 72 |
+
# 添加实体
|
| 73 |
+
entities = result.get("entities", [])
|
| 74 |
+
for entity in entities:
|
| 75 |
+
self.add_entity(
|
| 76 |
+
name=entity["name"],
|
| 77 |
+
entity_type=entity.get("type", "UNKNOWN"),
|
| 78 |
+
description=entity.get("description", "")
|
| 79 |
+
)
|
| 80 |
+
total_entities += 1
|
| 81 |
+
|
| 82 |
+
# 添加关系
|
| 83 |
+
relations = result.get("relations", [])
|
| 84 |
+
for relation in relations:
|
| 85 |
+
source = relation.get("source")
|
| 86 |
+
target = relation.get("target")
|
| 87 |
+
|
| 88 |
+
# 确保节点存在
|
| 89 |
+
if source in self.graph and target in self.graph:
|
| 90 |
+
self.add_relation(
|
| 91 |
+
source=source,
|
| 92 |
+
target=target,
|
| 93 |
+
relation_type=relation.get("relation_type", "RELATED_TO"),
|
| 94 |
+
description=relation.get("description", "")
|
| 95 |
+
)
|
| 96 |
+
total_relations += 1
|
| 97 |
+
|
| 98 |
+
print(f"✅ 图谱构建完成: {total_entities} 个实体, {total_relations} 个关系")
|
| 99 |
+
print(f" 实际节点数: {self.graph.number_of_nodes()}")
|
| 100 |
+
print(f" 实际边数: {self.graph.number_of_edges()}")
|
| 101 |
+
|
| 102 |
+
def detect_communities(self, algorithm: str = "louvain") -> Dict[str, int]:
|
| 103 |
+
"""
|
| 104 |
+
社区检测 - GraphRAG的核心组件
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
algorithm: 社区检测算法 ('louvain', 'greedy', 'label_propagation')
|
| 108 |
+
|
| 109 |
+
Returns:
|
| 110 |
+
节点到社区ID的映射
|
| 111 |
+
"""
|
| 112 |
+
print(f"🔍 开始社区检测 (算法: {algorithm})...")
|
| 113 |
+
|
| 114 |
+
if self.graph.number_of_nodes() == 0:
|
| 115 |
+
print("⚠️ 图谱为空,跳过社区检测")
|
| 116 |
+
return {}
|
| 117 |
+
|
| 118 |
+
try:
|
| 119 |
+
if algorithm == "louvain" and LOUVAIN_AVAILABLE:
|
| 120 |
+
communities = community_louvain.best_partition(self.graph)
|
| 121 |
+
elif algorithm == "greedy":
|
| 122 |
+
communities_generator = nx.community.greedy_modularity_communities(self.graph)
|
| 123 |
+
communities = {}
|
| 124 |
+
for idx, community_set in enumerate(communities_generator):
|
| 125 |
+
for node in community_set:
|
| 126 |
+
communities[node] = idx
|
| 127 |
+
elif algorithm == "label_propagation":
|
| 128 |
+
communities_generator = nx.community.label_propagation_communities(self.graph)
|
| 129 |
+
communities = {}
|
| 130 |
+
for idx, community_set in enumerate(communities_generator):
|
| 131 |
+
for node in community_set:
|
| 132 |
+
communities[node] = idx
|
| 133 |
+
else:
|
| 134 |
+
print(f"⚠️ 未知算法 {algorithm},使用贪婪算法")
|
| 135 |
+
communities_generator = nx.community.greedy_modularity_communities(self.graph)
|
| 136 |
+
communities = {}
|
| 137 |
+
for idx, community_set in enumerate(communities_generator):
|
| 138 |
+
for node in community_set:
|
| 139 |
+
communities[node] = idx
|
| 140 |
+
|
| 141 |
+
self.communities = communities
|
| 142 |
+
num_communities = len(set(communities.values()))
|
| 143 |
+
print(f"✅ 检测到 {num_communities} 个社区")
|
| 144 |
+
|
| 145 |
+
return communities
|
| 146 |
+
|
| 147 |
+
except Exception as e:
|
| 148 |
+
print(f"❌ 社区检测失败: {e}")
|
| 149 |
+
return {}
|
| 150 |
+
|
| 151 |
+
def get_community_members(self, community_id: int) -> List[str]:
|
| 152 |
+
"""获取指定社区的所有成员"""
|
| 153 |
+
return [node for node, cid in self.communities.items() if cid == community_id]
|
| 154 |
+
|
| 155 |
+
def get_community_subgraph(self, community_id: int) -> nx.Graph:
|
| 156 |
+
"""获取指定社区的子图"""
|
| 157 |
+
members = self.get_community_members(community_id)
|
| 158 |
+
return self.graph.subgraph(members)
|
| 159 |
+
|
| 160 |
+
def get_node_neighbors(self, node: str, depth: int = 1) -> Set[str]:
|
| 161 |
+
"""获取节点的邻居(支持多跳)"""
|
| 162 |
+
if node not in self.graph:
|
| 163 |
+
return set()
|
| 164 |
+
|
| 165 |
+
neighbors = {node}
|
| 166 |
+
current_layer = {node}
|
| 167 |
+
|
| 168 |
+
for _ in range(depth):
|
| 169 |
+
next_layer = set()
|
| 170 |
+
for n in current_layer:
|
| 171 |
+
next_layer.update(self.graph.neighbors(n))
|
| 172 |
+
neighbors.update(next_layer)
|
| 173 |
+
current_layer = next_layer
|
| 174 |
+
|
| 175 |
+
return neighbors
|
| 176 |
+
|
| 177 |
+
def get_entity_info(self, entity_name: str) -> Optional[Dict]:
|
| 178 |
+
"""获取实体详细信息"""
|
| 179 |
+
return self.entities.get(entity_name)
|
| 180 |
+
|
| 181 |
+
def search_entities_by_type(self, entity_type: str) -> List[str]:
|
| 182 |
+
"""按类型搜索实体"""
|
| 183 |
+
return [
|
| 184 |
+
name for name, data in self.entities.items()
|
| 185 |
+
if data.get("type") == entity_type
|
| 186 |
+
]
|
| 187 |
+
|
| 188 |
+
def get_statistics(self) -> Dict:
|
| 189 |
+
"""获取图谱统计信息"""
|
| 190 |
+
stats = {
|
| 191 |
+
"num_nodes": self.graph.number_of_nodes(),
|
| 192 |
+
"num_edges": self.graph.number_of_edges(),
|
| 193 |
+
"num_communities": len(set(self.communities.values())) if self.communities else 0,
|
| 194 |
+
"density": nx.density(self.graph) if self.graph.number_of_nodes() > 0 else 0,
|
| 195 |
+
"entity_types": {}
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
# 统计实体类型分布
|
| 199 |
+
for entity in self.entities.values():
|
| 200 |
+
etype = entity.get("type", "UNKNOWN")
|
| 201 |
+
stats["entity_types"][etype] = stats["entity_types"].get(etype, 0) + 1
|
| 202 |
+
|
| 203 |
+
return stats
|
| 204 |
+
|
| 205 |
+
def save_to_file(self, filepath: str):
|
| 206 |
+
"""保存图谱到文件"""
|
| 207 |
+
data = {
|
| 208 |
+
"entities": self.entities,
|
| 209 |
+
"edges": [
|
| 210 |
+
{
|
| 211 |
+
"source": u,
|
| 212 |
+
"target": v,
|
| 213 |
+
"data": data
|
| 214 |
+
}
|
| 215 |
+
for u, v, data in self.graph.edges(data=True)
|
| 216 |
+
],
|
| 217 |
+
"communities": self.communities,
|
| 218 |
+
"community_summaries": self.community_summaries
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
with open(filepath, 'w', encoding='utf-8') as f:
|
| 222 |
+
json.dump(data, f, ensure_ascii=False, indent=2)
|
| 223 |
+
|
| 224 |
+
print(f"✅ 图谱已保存到: {filepath}")
|
| 225 |
+
|
| 226 |
+
def load_from_file(self, filepath: str):
|
| 227 |
+
"""从文件加载图谱"""
|
| 228 |
+
with open(filepath, 'r', encoding='utf-8') as f:
|
| 229 |
+
data = json.load(f)
|
| 230 |
+
|
| 231 |
+
self.entities = data.get("entities", {})
|
| 232 |
+
        self.communities = data.get("communities", {})
        self.community_summaries = data.get("community_summaries", {})

        # Rebuild the graph
        self.graph.clear()
        for name, entity in self.entities.items():
            self.add_entity(**entity)

        for edge in data.get("edges", []):
            self.graph.add_edge(
                edge["source"],
                edge["target"],
                **edge["data"]
            )

        print(f"✅ Graph loaded from file: {filepath}")


class CommunitySummarizer:
    """Community summary generator - a key GraphRAG component"""

    def __init__(self):
        self.llm = ChatOllama(model=LOCAL_LLM, temperature=0.3)

        self.summary_prompt = PromptTemplate(
            template="""You are a knowledge graph analysis expert. Generate a comprehensive summary for the following community.

Community members (entities):
{entities}

Relations between entities:
{relations}

Write a concise summary that describes:
1. What the community's topic is
2. Which core concepts it contains
3. The key relations between the entities

Summary (2-3 sentences):
""",
            input_variables=["entities", "relations"]
        )

        self.summary_chain = self.summary_prompt | self.llm | StrOutputParser()

    def summarize_community(self, kg: KnowledgeGraph, community_id: int) -> str:
        """
        Generate a summary for the given community

        Args:
            kg: knowledge graph object
            community_id: community ID

        Returns:
            Community summary text
        """
        members = kg.get_community_members(community_id)
        subgraph = kg.get_community_subgraph(community_id)

        # Prepare entity information
        entity_info = []
        for member in members[:20]:  # cap the number of entities
            info = kg.get_entity_info(member)
            if info:
                entity_info.append(
                    f"- {info['name']} ({info.get('type', 'UNKNOWN')}): {info.get('description', 'no description')}"
                )

        # Prepare relation information
        relation_info = []
        for u, v, data in subgraph.edges(data=True):
            relation_info.append(
                f"- {u} --[{data.get('relation_type', 'RELATED')}]--> {v}"
            )

        entities_text = "\n".join(entity_info) if entity_info else "no entities"
        relations_text = "\n".join(relation_info[:15]) if relation_info else "no relations"

        try:
            summary = self.summary_chain.invoke({
                "entities": entities_text,
                "relations": relations_text
            })
            return summary.strip()
        except Exception as e:
            print(f"❌ Summary generation failed for community {community_id}: {e}")
            return f"Community {community_id}: contains {len(members)} entities"

    def summarize_all_communities(self, kg: KnowledgeGraph) -> Dict[int, str]:
        """Generate summaries for all communities"""
        if not kg.communities:
            print("⚠️ No communities detected; run community detection first")
            return {}

        community_ids = set(kg.communities.values())
        print(f"📝 Generating summaries for {len(community_ids)} communities...")

        summaries = {}
        for cid in community_ids:
            print(f"  Processing community {cid}...")
            summary = self.summarize_community(kg, cid)
            summaries[cid] = summary
            kg.community_summaries[cid] = summary

        print("✅ All community summaries generated")
        return summaries


def initialize_knowledge_graph():
    """Initialize the knowledge graph"""
    return KnowledgeGraph()


def initialize_community_summarizer():
    """Initialize the community summarizer"""
    return CommunitySummarizer()
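Usage sketch (illustrative, not part of the committed file; it uses only names defined above and assumes the graph has already been populated and its communities detected elsewhere, e.g. by graph_indexer):

# Summarize an already-indexed graph.
kg = initialize_knowledge_graph()
# ... entities/edges added and community detection run elsewhere ...
summarizer = initialize_community_summarizer()
summaries = summarizer.summarize_all_communities(kg)  # {community_id: summary text}
for cid, text in summaries.items():
    print(cid, "->", text)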
local_llm_rag.py
ADDED
@@ -0,0 +1,428 @@
"""
Source file: local_llm_rag.py
This is a simple example of how to use LangChain to build a local LLM RAG system.
"""
import getpass
import os


def _set_env(var: str):
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"{var}: ")


_set_env("TAVILY_API_KEY")
_set_env("NOMIC_API_KEY")

# Ollama model name
local_llm = "mistral"

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_nomic.embeddings import NomicEmbeddings

urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=250, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)

# Add to vectorDB
vectorstore = Chroma.from_documents(
    documents=doc_splits,
    collection_name="rag-chroma",
    embedding=NomicEmbeddings(model="nomic-embed-text-v1.5", inference_mode="local"),
)
retriever = vectorstore.as_retriever()

### Router

from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import JsonOutputParser

# LLM
llm = ChatOllama(model=local_llm, format="json", temperature=0)

prompt = PromptTemplate(
    template="""You are an expert at routing a user question to a vectorstore or web search. \n
    Use the vectorstore for questions on LLM agents, prompt engineering, and adversarial attacks. \n
    You do not need to be stringent with the keywords in the question related to these topics. \n
    Otherwise, use web-search. Give a binary choice 'web_search' or 'vectorstore' based on the question. \n
    Return a JSON with a single key 'datasource' and no preamble or explanation. \n
    Question to route: {question}""",
    input_variables=["question"],
)

question_router = prompt | llm | JsonOutputParser()
question = "llm agent memory"
docs = retriever.get_relevant_documents(question)
doc_txt = docs[1].page_content
print(question_router.invoke({"question": question}))


### Retrieval Grader
# NOTE: `retrieval_grader` is required by grade_documents() below but is not
# defined elsewhere in this file; this section is reconstructed on the same
# pattern as the other graders so the script runs end to end.

llm = ChatOllama(model=local_llm, format="json", temperature=0)

prompt = PromptTemplate(
    template="""You are a grader assessing relevance of a retrieved document to a user question. \n
    Here is the retrieved document: \n\n {document} \n\n
    Here is the user question: {question} \n
    If the document contains keywords related to the user question, grade it as relevant. \n
    Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question. \n
    Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.""",
    input_variables=["question", "document"],
)

retrieval_grader = prompt | llm | JsonOutputParser()


### Generate

from langchain import hub
from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import StrOutputParser

# Prompt
prompt = hub.pull("rlm/rag-prompt")

# LLM
llm = ChatOllama(model=local_llm, temperature=0)


# Post-processing
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)


# Chain
rag_chain = prompt | llm | StrOutputParser()

# Run
question = "agent memory"
generation = rag_chain.invoke({"context": docs, "question": question})
print(generation)

### Hallucination Grader
# NOTE: like `retrieval_grader`, `hallucination_grader` is referenced below but
# not defined elsewhere in this file; reconstructed on the same pattern.

llm = ChatOllama(model=local_llm, format="json", temperature=0)

prompt = PromptTemplate(
    template="""You are a grader assessing whether an answer is grounded in / supported by a set of facts. \n
    Here are the facts:
    \n ------- \n
    {documents}
    \n ------- \n
    Here is the answer: {generation}
    Give a binary score 'yes' or 'no' to indicate whether the answer is grounded in / supported by the facts. \n
    Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.""",
    input_variables=["generation", "documents"],
)

hallucination_grader = prompt | llm | JsonOutputParser()

### Answer Grader

# LLM
llm = ChatOllama(model=local_llm, format="json", temperature=0)

# Prompt
prompt = PromptTemplate(
    template="""You are a grader assessing whether an answer is useful to resolve a question. \n
    Here is the answer:
    \n ------- \n
    {generation}
    \n ------- \n
    Here is the question: {question}
    Give a binary score 'yes' or 'no' to indicate whether the answer is useful to resolve the question. \n
    Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.""",
    input_variables=["generation", "question"],
)

answer_grader = prompt | llm | JsonOutputParser()
answer_grader.invoke({"question": question, "generation": generation})

### Question Re-writer

# LLM
llm = ChatOllama(model=local_llm, temperature=0)

# Prompt
re_write_prompt = PromptTemplate(
    template="""You are a question re-writer that converts an input question to a better version that is optimized \n
    for vectorstore retrieval. Look at the initial question and formulate an improved question. \n
    Here is the initial question: \n\n {question}. Improved question with no preamble: \n """,
    input_variables=["generation", "question"],
)

question_rewriter = re_write_prompt | llm | StrOutputParser()
question_rewriter.invoke({"question": question})


### Search

from langchain_community.tools.tavily_search import TavilySearchResults

web_search_tool = TavilySearchResults(k=3)

from typing import List

from typing_extensions import TypedDict


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: question
        generation: LLM generation
        documents: list of documents
    """

    question: str
    generation: str
    documents: List[str]

from langchain.schema import Document


def retrieve(state):
    """
    Retrieve documents

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE---")
    question = state["question"]

    # Retrieval
    documents = retriever.get_relevant_documents(question)
    return {"documents": documents, "question": question}


def generate(state):
    """
    Generate answer

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation, that contains LLM generation
    """
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]

    # RAG generation
    generation = rag_chain.invoke({"context": documents, "question": question})
    return {"documents": documents, "question": question, "generation": generation}


def grade_documents(state):
    """
    Determines whether the retrieved documents are relevant to the question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates documents key with only filtered relevant documents
    """

    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    question = state["question"]
    documents = state["documents"]

    # Score each doc
    filtered_docs = []
    for d in documents:
        score = retrieval_grader.invoke(
            {"question": question, "document": d.page_content}
        )
        grade = score["score"]
        if grade == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            continue
    return {"documents": filtered_docs, "question": question}


def transform_query(state):
    """
    Transform the query to produce a better question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates question key with a re-phrased question
    """

    print("---TRANSFORM QUERY---")
    question = state["question"]
    documents = state["documents"]

    # Re-write question
    better_question = question_rewriter.invoke({"question": question})
    return {"documents": documents, "question": better_question}


def web_search(state):
    """
    Web search based on the re-phrased question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates documents key with appended web results
    """

    print("---WEB SEARCH---")
    question = state["question"]

    # Web search
    docs = web_search_tool.invoke({"query": question})
    web_results = "\n".join([d["content"] for d in docs])
    web_results = Document(page_content=web_results)

    return {"documents": web_results, "question": question}


### Edges ###


def route_question(state):
    """
    Route question to web search or RAG.

    Args:
        state (dict): The current graph state

    Returns:
        str: Next node to call
    """

    print("---ROUTE QUESTION---")
    question = state["question"]
    print(question)
    source = question_router.invoke({"question": question})
    print(source)
    print(source["datasource"])
    if source["datasource"] == "web_search":
        print("---ROUTE QUESTION TO WEB SEARCH---")
        return "web_search"
    elif source["datasource"] == "vectorstore":
        print("---ROUTE QUESTION TO RAG---")
        return "vectorstore"


def decide_to_generate(state):
    """
    Determines whether to generate an answer, or re-generate a question.

    Args:
        state (dict): The current graph state

    Returns:
        str: Binary decision for next node to call
    """

    print("---ASSESS GRADED DOCUMENTS---")
    filtered_documents = state["documents"]

    if not filtered_documents:
        # All documents have been filtered out by check_relevance
        # We will re-generate a new query
        print(
            "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
        )
        return "transform_query"
    else:
        # We have relevant documents, so generate answer
        print("---DECISION: GENERATE---")
        return "generate"


def grade_generation_v_documents_and_question(state):
    """
    Determines whether the generation is grounded in the document and answers the question.

    Args:
        state (dict): The current graph state

    Returns:
        str: Decision for next node to call
    """

    print("---CHECK HALLUCINATIONS---")
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]

    score = hallucination_grader.invoke(
        {"documents": documents, "generation": generation}
    )
    grade = score["score"]

    # Check hallucination
    if grade == "yes":
        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
        # Check question-answering
        print("---GRADE GENERATION vs QUESTION---")
        score = answer_grader.invoke({"question": question, "generation": generation})
        grade = score["score"]
        if grade == "yes":
            print("---DECISION: GENERATION ADDRESSES QUESTION---")
            return "useful"
        else:
            print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
            return "not useful"
    else:
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        return "not supported"

from langgraph.graph import END, StateGraph, START

workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("web_search", web_search)  # web search
workflow.add_node("retrieve", retrieve)  # retrieve
workflow.add_node("grade_documents", grade_documents)  # grade documents
workflow.add_node("generate", generate)  # generate
workflow.add_node("transform_query", transform_query)  # transform_query

# Build graph
workflow.add_conditional_edges(
    START,
    route_question,
    {
        "web_search": "web_search",
        "vectorstore": "retrieve",
    },
)
workflow.add_edge("web_search", "generate")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "retrieve")
workflow.add_conditional_edges(
    "generate",
    grade_generation_v_documents_and_question,
    {
        "not supported": "generate",
        "useful": END,
        "not useful": "transform_query",
    },
)

# Compile
app = workflow.compile()

from pprint import pprint

# Run
inputs = {"question": "What is the AlphaCodium paper about?"}
for output in app.stream(inputs):
    for key, value in output.items():
        # Node
        pprint(f"Node '{key}':")
        # Optional: print full state at each node
        # pprint.pprint(value["keys"], indent=2, width=80, depth=None)
    pprint("\n---\n")

# Final generation
pprint(value["generation"])
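For a single non-streaming run, the compiled LangGraph app can also be invoked directly; a minimal sketch (the question string is illustrative):

result = app.invoke({"question": "What is agent memory?"})
print(result["generation"])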
main.py
ADDED
@@ -0,0 +1,174 @@
"""
Main application entry point
Integrates all modules, builds the workflow, and runs the adaptive RAG system
"""

from langgraph.graph import END, StateGraph, START
from pprint import pprint

from config import setup_environment
from document_processor import initialize_document_processor
from routers_and_graders import initialize_graders_and_router
from workflow_nodes import WorkflowNodes, GraphState


class AdaptiveRAGSystem:
    """Main class of the adaptive RAG system"""

    def __init__(self):
        print("Initializing the adaptive RAG system...")

        # Set up the environment and validate API keys
        try:
            setup_environment()
            print("✅ API keys validated")
        except ValueError as e:
            print(f"❌ {e}")
            raise

        # Initialize the document processor
        print("Setting up the document processor...")
        self.doc_processor, self.vectorstore, self.retriever = initialize_document_processor()

        # Initialize the graders and router
        print("Initializing graders and router...")
        self.graders = initialize_graders_and_router()

        # Initialize the workflow nodes
        print("Setting up workflow nodes...")
        self.workflow_nodes = WorkflowNodes(self.retriever, self.graders)

        # Build the workflow
        print("Building the workflow graph...")
        self.app = self._build_workflow()

        print("✅ Adaptive RAG system initialized!")

    def _build_workflow(self):
        """Build the workflow graph"""
        workflow = StateGraph(GraphState)

        # Define the nodes
        workflow.add_node("web_search", self.workflow_nodes.web_search)
        workflow.add_node("retrieve", self.workflow_nodes.retrieve)
        workflow.add_node("grade_documents", self.workflow_nodes.grade_documents)
        workflow.add_node("generate", self.workflow_nodes.generate)
        workflow.add_node("transform_query", self.workflow_nodes.transform_query)

        # Build the graph
        workflow.add_conditional_edges(
            START,
            self.workflow_nodes.route_question,
            {
                "web_search": "web_search",
                "vectorstore": "retrieve",
            },
        )
        workflow.add_edge("web_search", "generate")
        workflow.add_edge("retrieve", "grade_documents")
        workflow.add_conditional_edges(
            "grade_documents",
            self.workflow_nodes.decide_to_generate,
            {
                "transform_query": "transform_query",
                "generate": "generate",
            },
        )
        workflow.add_edge("transform_query", "retrieve")
        workflow.add_conditional_edges(
            "generate",
            self.workflow_nodes.grade_generation_v_documents_and_question,
            {
                "not supported": "generate",
                "useful": END,
                "not useful": "transform_query",
            },
        )

        # Compile
        return workflow.compile()

    def query(self, question: str, verbose: bool = True):
        """
        Process a query

        Args:
            question (str): user question
            verbose (bool): whether to print detailed output

        Returns:
            str: final answer
        """
        print(f"\n🔍 Processing question: {question}")
        print("=" * 50)

        inputs = {"question": question}
        final_generation = None

        for output in self.app.stream(inputs):
            for key, value in output.items():
                if verbose:
                    pprint(f"Node '{key}':")
                    # Optional: print the full state at each node
                    # pprint(value, indent=2, width=80, depth=None)
                final_generation = value.get("generation", final_generation)
            if verbose:
                pprint("\n---\n")

        print("🎯 Final answer:")
        print("-" * 30)
        print(final_generation)
        print("=" * 50)

        return final_generation

    def interactive_mode(self):
        """Interactive mode that lets the user keep asking questions"""
        print("\n🤖 Welcome to the adaptive RAG system!")
        print("💡 Type a question to start; type 'quit' or 'exit' to leave")
        print("-" * 50)

        while True:
            try:
                question = input("\n❓ Enter your question: ").strip()

                if question.lower() in ['quit', 'exit', '退出', 'q']:
                    print("👋 Thanks for using the system, goodbye!")
                    break

                if not question:
                    print("⚠️ Please enter a valid question")
                    continue

                self.query(question)

            except KeyboardInterrupt:
                print("\n👋 Thanks for using the system, goodbye!")
                break
            except Exception as e:
                print(f"❌ An error occurred: {e}")
                print("Please retry, or type 'quit' to exit")


def main():
    """Main function"""
    try:
        # Initialize the system
        rag_system = AdaptiveRAGSystem()

        # Test query
        # test_question = "What is the AlphaCodium paper about?"
        test_question = "Explain how embeddings work, ideally listing the concrete implementation steps"
        rag_system.query(test_question)

        # Start interactive mode
        rag_system.interactive_mode()

    except Exception as e:
        print(f"❌ System initialization failed: {e}")
        print("Please check that the configuration and dependencies are installed correctly")


if __name__ == "__main__":
    main()
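Programmatic usage sketch (illustrative; assumes the environment and sibling modules are configured, and bypasses interactive_mode):

rag = AdaptiveRAGSystem()
answer = rag.query("What is prompt engineering?", verbose=False)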
main_graphrag.py
ADDED
@@ -0,0 +1,293 @@
"""
GraphRAG integration example
Shows how to use the knowledge graph features inside the adaptive RAG system
"""

import os
from pprint import pprint

from config import (
    setup_environment,
    ENABLE_GRAPHRAG,
    GRAPHRAG_INDEX_PATH,
    GRAPHRAG_BATCH_SIZE
)
from document_processor import initialize_document_processor
from graph_indexer import initialize_graph_indexer
from graph_retriever import initialize_graph_retriever


class AdaptiveRAGWithGraph:
    """Adaptive RAG system with GraphRAG integration"""

    def __init__(self, enable_graphrag=True, rebuild_index=False):
        print("🚀 Initializing the GraphRAG-enabled adaptive RAG system...")
        print("="*60)

        # Set up the environment
        try:
            setup_environment()
            print("✅ Environment configured")
        except ValueError as e:
            print(f"❌ {e}")
            raise

        # Initialize the document processor
        print("\n📚 Initializing the document processor...")
        self.doc_processor, self.vectorstore, self.retriever, self.doc_splits = \
            initialize_document_processor()

        # GraphRAG components
        self.enable_graphrag = enable_graphrag and ENABLE_GRAPHRAG
        self.graph_indexer = None
        self.graph_retriever = None
        self.knowledge_graph = None

        if self.enable_graphrag:
            self._setup_graphrag(rebuild_index)

        print("\n" + "="*60)
        print("✅ System initialization complete!")
        print("="*60)

    def _setup_graphrag(self, rebuild_index=False):
        """Set up the GraphRAG components"""
        print("\n🔷 Setting up GraphRAG components...")

        # Initialize the indexer
        self.graph_indexer = initialize_graph_indexer()

        # Check whether an index already exists
        index_exists = os.path.exists(GRAPHRAG_INDEX_PATH)

        if index_exists and not rebuild_index:
            print(f"📂 Found existing index: {GRAPHRAG_INDEX_PATH}")
            print("   Loading existing index...")
            self.knowledge_graph = self.graph_indexer.load_index(GRAPHRAG_INDEX_PATH)
        else:
            if rebuild_index:
                print("🔄 Rebuilding index...")
            else:
                print("📝 Building index for the first time...")

            # Build the index
            self.knowledge_graph = self.graph_indexer.index_documents(
                documents=self.doc_splits,
                batch_size=GRAPHRAG_BATCH_SIZE,
                save_path=GRAPHRAG_INDEX_PATH
            )

        # Initialize the retriever
        self.graph_retriever = initialize_graph_retriever(self.knowledge_graph)
        print("✅ GraphRAG components ready")

    def query_vector_only(self, question: str) -> str:
        """Vector retrieval only"""
        print(f"\n{'='*60}")
        print("🔍 Vector retrieval mode")
        print(f"Question: {question}")
        print(f"{'='*60}")

        docs = self.retriever.get_relevant_documents(question)

        print(f"\n📄 Retrieved {len(docs)} document chunks:")
        for i, doc in enumerate(docs[:3], 1):
            print(f"\nChunk {i}:")
            print(f"{doc.page_content[:200]}...")

        return self.doc_processor.format_docs(docs)

    def query_graph_local(self, question: str) -> str:
        """Graph local query"""
        if not self.enable_graphrag:
            return "GraphRAG is not enabled"

        print(f"\n{'='*60}")
        print("🔎 Graph local query mode")
        print(f"Question: {question}")
        print(f"{'='*60}")

        answer = self.graph_retriever.local_query(question)

        print("\n💡 Answer:")
        print(answer)

        return answer

    def query_graph_global(self, question: str) -> str:
        """Graph global query"""
        if not self.enable_graphrag:
            return "GraphRAG is not enabled"

        print(f"\n{'='*60}")
        print("🌍 Graph global query mode")
        print(f"Question: {question}")
        print(f"{'='*60}")

        answer = self.graph_retriever.global_query(question)

        print("\n💡 Answer:")
        print(answer)

        return answer

    def query_hybrid(self, question: str) -> dict:
        """Hybrid query: vector + graph"""
        if not self.enable_graphrag:
            return {"error": "GraphRAG is not enabled"}

        print(f"\n{'='*60}")
        print("🔀 Hybrid query mode")
        print(f"Question: {question}")
        print(f"{'='*60}")

        # Vector retrieval
        vector_docs = self.retriever.get_relevant_documents(question)
        vector_context = self.doc_processor.format_docs(vector_docs[:3])

        # Graph query
        graph_results = self.graph_retriever.hybrid_query(question)

        result = {
            "question": question,
            "vector_retrieval": {
                "doc_count": len(vector_docs),
                "context": vector_context[:500] + "..." if len(vector_context) > 500 else vector_context
            },
            "graph_local": graph_results["local"],
            "graph_global": graph_results["global"]
        }

        print("\n📊 Result summary:")
        print(f"  • Vector retrieval: {len(vector_docs)} documents")
        print("  • Graph local query done")
        print("  • Graph global query done")

        return result

    def query_smart(self, question: str) -> str:
        """Smart query: automatically pick the best strategy"""
        if not self.enable_graphrag:
            return self.query_vector_only(question)

        print(f"\n{'='*60}")
        print("🧠 Smart query mode")
        print(f"Question: {question}")
        print(f"{'='*60}")

        answer = self.graph_retriever.smart_query(question)

        print("\n💡 Answer:")
        print(answer)

        return answer

    def get_graph_statistics(self):
        """Get knowledge graph statistics"""
        if not self.enable_graphrag or not self.knowledge_graph:
            print("GraphRAG is not enabled or the graph has not been built")
            return

        stats = self.knowledge_graph.get_statistics()

        print("\n" + "="*60)
        print("📊 Knowledge graph statistics")
        print("="*60)
        print(f"Nodes: {stats['num_nodes']}")
        print(f"Edges: {stats['num_edges']}")
        print(f"Communities: {stats['num_communities']}")
        print(f"Graph density: {stats['density']:.4f}")
        print("\nEntity type distribution:")
        for etype, count in stats['entity_types'].items():
            print(f"  • {etype}: {count}")
        print("="*60)

        return stats

    def interactive_mode(self):
        """Interactive mode"""
        print("\n" + "="*60)
        print("🤖 Welcome to the GraphRAG-enhanced adaptive RAG system!")
        print("="*60)
        print("\nQuery modes:")
        print("  1️⃣ vector - vector retrieval only")
        print("  2️⃣ local  - graph local query")
        print("  3️⃣ global - graph global query")
        print("  4️⃣ hybrid - hybrid query")
        print("  5️⃣ smart  - smart query (recommended)")
        print("  6️⃣ stats  - show graph statistics")
        print("  7️⃣ quit   - exit")
        print("-"*60)

        while True:
            try:
                mode = input("\nChoose a mode (1-7): ").strip()

                if mode in ['7', 'quit', 'exit', '退出', 'q']:
                    print("👋 Thanks for using the system, goodbye!")
                    break

                if mode in ['6', 'stats']:
                    self.get_graph_statistics()
                    continue

                question = input("❓ Enter a question: ").strip()

                if not question:
                    print("⚠️ Please enter a valid question")
                    continue

                if mode in ['1', 'vector']:
                    self.query_vector_only(question)
                elif mode in ['2', 'local']:
                    self.query_graph_local(question)
                elif mode in ['3', 'global']:
                    self.query_graph_global(question)
                elif mode in ['4', 'hybrid']:
                    result = self.query_hybrid(question)
                    pprint(result)
                else:  # default: smart mode
                    self.query_smart(question)

            except KeyboardInterrupt:
                print("\n👋 Thanks for using the system, goodbye!")
                break
            except Exception as e:
                print(f"❌ An error occurred: {e}")
                print("Please retry, or type 'quit' to exit")


def main():
    """Main function"""
    try:
        # Initialize the system (set rebuild_index=True on the first run)
        rag_system = AdaptiveRAGWithGraph(
            enable_graphrag=True,
            rebuild_index=False  # set True to rebuild the index
        )

        # Show graph statistics
        rag_system.get_graph_statistics()

        # Test queries
        print("\n" + "="*60)
        print("🧪 Example test queries")
        print("="*60)

        # Example 1: local query
        rag_system.query_graph_local("What are the main components of an LLM agent?")

        # Example 2: global query
        rag_system.query_graph_global("What are the main topics these documents discuss?")

        # Start interactive mode
        rag_system.interactive_mode()

    except Exception as e:
        print(f"❌ System initialization failed: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()
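Programmatic usage sketch (illustrative; drives the class above without the interactive menu, assuming an index already exists at GRAPHRAG_INDEX_PATH):

rag = AdaptiveRAGWithGraph(enable_graphrag=True, rebuild_index=False)
rag.get_graph_statistics()
rag.query_smart("What are the main components of an LLM agent?")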
requirements.txt
ADDED
@@ -0,0 +1,44 @@
# Dependencies for the adaptive RAG system
# Install with: pip install -r requirements.txt

# Core framework
langchain>=0.1.0
langgraph>=0.0.40
langchain-community>=0.0.20
langchain-core>=0.1.0

# LLM integration
langchain-ollama>=0.1.0

# Vector store and embeddings
chromadb>=0.4.0
sentence-transformers>=2.2.0
torch>=2.0.0
transformers>=4.30.0

# Document processing
tiktoken>=0.5.0
beautifulsoup4>=4.12.0
requests>=2.31.0

# Web search
tavily-python>=0.3.0

# Data processing
numpy>=1.24.0,<2.0  # avoid NumPy 2.x compatibility issues
pandas>=2.0.0
scikit-learn>=1.3.0  # required by reranker.py (TF-IDF, cosine similarity)

# Utilities
python-dotenv>=1.0.0
pydantic>=2.0.0
typing-extensions>=4.0.0

# Development tools (optional)
jupyter>=1.0.0
ipykernel>=6.0.0
matplotlib>=3.7.0
seaborn>=0.12.0

# GraphRAG-related (optional)
networkx>=3.1  # graph data structures
python-louvain>=0.16  # community detection
requirements_gpu.txt
ADDED
@@ -0,0 +1,56 @@
# GPU-optimized dependency file - targets an RTX 4090 Linux environment
# Install with: pip install -r requirements_gpu.txt

# Core framework
langchain>=0.1.0
langgraph>=0.0.40
langchain-community>=0.0.20
langchain-core>=0.1.0

# LLM integration
langchain-ollama>=0.1.0

# Vector store and embeddings (GPU-optimized versions)
chromadb>=0.4.0
sentence-transformers>=2.2.0

# PyTorch GPU builds (CUDA 11.8); the extra index supplies the cu118 wheels,
# so the packages above can still resolve from PyPI
--extra-index-url https://download.pytorch.org/whl/cu118
torch>=2.0.0
torchvision>=0.15.0
torchaudio>=2.0.0

# Transformers and acceleration libraries
transformers>=4.30.0
accelerate>=0.20.0

# GPU-accelerated vector search
faiss-gpu>=1.7.4

# CUDA support library (cu11x build, matching the CUDA 11.8 stack above)
cupy-cuda11x>=12.0.0

# Document processing
tiktoken>=0.5.0
beautifulsoup4>=4.12.0
requests>=2.31.0

# Web search
tavily-python>=0.3.0

# Data processing (compatible versions)
numpy>=1.24.0,<2.0
pandas>=2.0.0

# Utilities
python-dotenv>=1.0.0
pydantic>=2.0.0
typing-extensions>=4.0.0

# System monitoring
gputil>=1.4.0
psutil>=5.9.0

# Optional: Jupyter support
jupyter>=1.0.0
ipykernel>=6.0.0
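A quick sanity check that the CUDA build installed correctly (illustrative):

import torch
print(torch.__version__, torch.cuda.is_available())  # expect True on a working CUDA 11.8 install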
requirements_graphrag.txt
ADDED
@@ -0,0 +1,31 @@
# Extra dependencies for GraphRAG
# Add these on top of the base requirements.txt
# Install with: pip install -r requirements_graphrag.txt

# Graph database
neo4j>=5.14.0
py2neo>=2021.2.3

# Or use the lighter-weight options
networkx>=3.1
python-louvain>=0.16  # community detection

# Graph processing
graspologic>=3.3.0  # hierarchical community detection
leidenalg>=0.10.0  # a better community detection algorithm

# GraphRAG core (optional, for Microsoft's official implementation)
# graphrag>=0.1.0

# Enhanced entity recognition
spacy>=3.7.0
# Download models: python -m spacy download zh_core_web_sm
# Download models: python -m spacy download en_core_web_sm

# Graph visualization (optional)
pyvis>=0.3.2
plotly>=5.18.0

# Caching and performance optimization
diskcache>=5.6.0
joblib>=1.3.0
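A minimal sketch of the networkx + python-louvain combination these requirements enable (the karate-club graph is a stand-in dataset for illustration):

import networkx as nx
import community as community_louvain  # provided by the python-louvain package

G = nx.karate_club_graph()
partition = community_louvain.best_partition(G)  # {node: community_id}
print(len(set(partition.values())), "communities detected")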
reranker.py
ADDED
@@ -0,0 +1,350 @@
"""
Vector reranking module
Implements several reranking strategies to improve retrieval quality
"""

import numpy as np
from typing import List, Tuple, Dict
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import re
from collections import Counter
import math


class DocumentReranker:
    """Base class for document rerankers"""

    def __init__(self):
        self.name = "BaseReranker"

    def rerank(self, query: str, documents: List[dict], top_k: int = 5) -> List[Tuple[dict, float]]:
        """Rerank documents and return the top_k results"""
        raise NotImplementedError


class TFIDFReranker(DocumentReranker):
    """TF-IDF-based reranker"""

    def __init__(self):
        super().__init__()
        self.name = "TFIDFReranker"
        self.vectorizer = TfidfVectorizer(stop_words='english', max_features=5000)

    def rerank(self, query: str, documents: List[dict], top_k: int = 5) -> List[Tuple[dict, float]]:
        """Rerank documents by TF-IDF cosine similarity"""
        if not documents:
            return []

        # Extract document contents
        doc_texts = [doc.page_content if hasattr(doc, 'page_content') else str(doc) for doc in documents]
        all_texts = [query] + doc_texts

        # Compute the TF-IDF matrix
        tfidf_matrix = self.vectorizer.fit_transform(all_texts)
        query_vec = tfidf_matrix[0]
        doc_vecs = tfidf_matrix[1:]

        # Compute similarities
        similarities = cosine_similarity(query_vec, doc_vecs).flatten()

        # Sort and return the top_k
        ranked_indices = np.argsort(similarities)[::-1]
        results = []
        for i in ranked_indices[:top_k]:
            results.append((documents[i], float(similarities[i])))

        return results


class BM25Reranker(DocumentReranker):
    """BM25-based reranker"""

    def __init__(self, k1: float = 1.5, b: float = 0.75):
        super().__init__()
        self.name = "BM25Reranker"
        self.k1 = k1
        self.b = b

    def _tokenize(self, text: str) -> List[str]:
        """Simple tokenization"""
        return re.findall(r'\b\w+\b', text.lower())

    def _compute_idf(self, documents: List[str], query_terms: List[str]) -> Dict[str, float]:
        """Compute IDF values"""
        N = len(documents)
        idf = {}

        for term in query_terms:
            df = sum(1 for doc in documents if term in self._tokenize(doc))
            idf[term] = math.log((N - df + 0.5) / (df + 0.5))

        return idf

    def _bm25_score(self, query_terms: List[str], document: str, avg_doc_len: float, idf: Dict[str, float]) -> float:
        """Compute the BM25 score"""
        doc_terms = self._tokenize(document)
        doc_len = len(doc_terms)
        term_freq = Counter(doc_terms)

        score = 0.0
        for term in query_terms:
            if term in term_freq:
                tf = term_freq[term]
                score += idf.get(term, 0) * (tf * (self.k1 + 1)) / (
                    tf + self.k1 * (1 - self.b + self.b * doc_len / avg_doc_len)
                )

        return score

    def rerank(self, query: str, documents: List[dict], top_k: int = 5) -> List[Tuple[dict, float]]:
        """Rerank documents with BM25"""
        if not documents:
            return []

        query_terms = self._tokenize(query)
        doc_texts = [doc.page_content if hasattr(doc, 'page_content') else str(doc) for doc in documents]

        # Compute the average document length
        avg_doc_len = sum(len(self._tokenize(doc)) for doc in doc_texts) / len(doc_texts)

        # Compute IDF
        idf = self._compute_idf(doc_texts, query_terms)

        # Compute BM25 scores
        scores = []
        for doc_text in doc_texts:
            score = self._bm25_score(query_terms, doc_text, avg_doc_len, idf)
            scores.append(score)

        # Sort and return the top_k
        ranked_indices = np.argsort(scores)[::-1]
        results = []
        for i in ranked_indices[:top_k]:
            results.append((documents[i], float(scores[i])))

        return results


class SemanticReranker(DocumentReranker):
    """Semantic-similarity-based reranker"""

    def __init__(self, embeddings_model):
        super().__init__()
        self.name = "SemanticReranker"
        self.embeddings_model = embeddings_model

    def rerank(self, query: str, documents: List[dict], top_k: int = 5) -> List[Tuple[dict, float]]:
        """Rerank documents by semantic similarity"""
        if not documents:
            return []

        # Embed the query
        query_embedding = self.embeddings_model.embed_query(query)

        # Embed the documents
        doc_texts = [doc.page_content if hasattr(doc, 'page_content') else str(doc) for doc in documents]
        doc_embeddings = self.embeddings_model.embed_documents(doc_texts)

        # Compute cosine similarities
        similarities = []
        for doc_emb in doc_embeddings:
            sim = cosine_similarity([query_embedding], [doc_emb])[0][0]
            similarities.append(sim)

        # Sort and return the top_k
        ranked_indices = np.argsort(similarities)[::-1]
        results = []
        for i in ranked_indices[:top_k]:
            results.append((documents[i], float(similarities[i])))

        return results


class HybridReranker(DocumentReranker):
    """Hybrid reranker that fuses several strategies"""

    def __init__(self, embeddings_model, weights: Dict[str, float] = None):
        super().__init__()
        self.name = "HybridReranker"

        # Initialize the individual rerankers
        self.tfidf_reranker = TFIDFReranker()
        self.bm25_reranker = BM25Reranker()
        self.semantic_reranker = SemanticReranker(embeddings_model)

        # Set the weights
        self.weights = weights or {
            'tfidf': 0.3,
            'bm25': 0.3,
            'semantic': 0.4
        }

    def rerank(self, query: str, documents: List[dict], top_k: int = 5) -> List[Tuple[dict, float]]:
        """Rerank documents with the combined strategy"""
        if not documents:
            return []

        # Get the results of each reranker
        tfidf_results = self.tfidf_reranker.rerank(query, documents, len(documents))
        bm25_results = self.bm25_reranker.rerank(query, documents, len(documents))
        semantic_results = self.semantic_reranker.rerank(query, documents, len(documents))

        # Map documents to their scores
        doc_scores = {}
        for doc in documents:
            doc_id = id(doc)
            doc_scores[doc_id] = {'doc': doc, 'tfidf': 0, 'bm25': 0, 'semantic': 0}

        # Fill in the individual scores
        for doc, score in tfidf_results:
            doc_scores[id(doc)]['tfidf'] = score

        for doc, score in bm25_results:
            doc_scores[id(doc)]['bm25'] = score

        for doc, score in semantic_results:
            doc_scores[id(doc)]['semantic'] = score

        # Normalize the scores
        for score_type in ['tfidf', 'bm25', 'semantic']:
            scores = [info[score_type] for info in doc_scores.values()]
            if max(scores) > 0:
                max_score = max(scores)
                for doc_id in doc_scores:
                    doc_scores[doc_id][score_type] /= max_score

        # Compute the combined score
        final_scores = []
        for doc_id, info in doc_scores.items():
            combined_score = (
                self.weights['tfidf'] * info['tfidf'] +
                self.weights['bm25'] * info['bm25'] +
                self.weights['semantic'] * info['semantic']
            )
            final_scores.append((info['doc'], combined_score))

        # Sort and return the top_k
        final_scores.sort(key=lambda x: x[1], reverse=True)
        return final_scores[:top_k]


class DiversityReranker(DocumentReranker):
    """Diversity-aware reranker that avoids redundant results"""

    def __init__(self, embeddings_model, diversity_lambda: float = 0.5):
        super().__init__()
        self.name = "DiversityReranker"
        self.embeddings_model = embeddings_model
        self.diversity_lambda = diversity_lambda

    def _calculate_diversity_penalty(self, candidate_doc: str, selected_docs: List[str]) -> float:
        """Compute the diversity penalty"""
        if not selected_docs:
            return 0.0

        candidate_emb = self.embeddings_model.embed_documents([candidate_doc])[0]
        selected_embs = self.embeddings_model.embed_documents(selected_docs)

        max_similarity = 0.0
        for selected_emb in selected_embs:
            sim = cosine_similarity([candidate_emb], [selected_emb])[0][0]
            max_similarity = max(max_similarity, sim)

        return max_similarity

    def rerank(self, query: str, documents: List[dict], top_k: int = 5) -> List[Tuple[dict, float]]:
        """Rerank documents with a diversity-aware strategy"""
        if not documents:
            return []

        # First get an initial ranking by semantic similarity
        semantic_results = SemanticReranker(self.embeddings_model).rerank(
            query, documents, len(documents)
        )

        # MMR (Maximal Marginal Relevance) algorithm
        selected_docs = []
        selected_texts = []
        remaining_docs = [doc for doc, _ in semantic_results]
        relevance_scores = {id(doc): score for doc, score in semantic_results}

        while len(selected_docs) < top_k and remaining_docs:
            best_score = -1
            best_doc = None
            best_idx = -1

            for i, doc in enumerate(remaining_docs):
                doc_text = doc.page_content if hasattr(doc, 'page_content') else str(doc)
                relevance = relevance_scores[id(doc)]
                diversity_penalty = self._calculate_diversity_penalty(doc_text, selected_texts)

                # MMR score = λ * relevance - (1-λ) * diversity penalty
                mmr_score = (
                    self.diversity_lambda * relevance -
                    (1 - self.diversity_lambda) * diversity_penalty
                )

                if mmr_score > best_score:
                    best_score = mmr_score
                    best_doc = doc
                    best_idx = i

            if best_doc is not None:
                selected_docs.append((best_doc, best_score))
                selected_texts.append(
                    best_doc.page_content if hasattr(best_doc, 'page_content') else str(best_doc)
                )
                remaining_docs.pop(best_idx)

        return selected_docs


def create_reranker(reranker_type: str, embeddings_model=None, **kwargs) -> DocumentReranker:
    """Factory function: create a reranker of the given type"""

    if reranker_type.lower() == 'tfidf':
        return TFIDFReranker()
    elif reranker_type.lower() == 'bm25':
        return BM25Reranker(**kwargs)
    elif reranker_type.lower() == 'semantic':
        if embeddings_model is None:
            raise ValueError("SemanticReranker requires embeddings_model")
        return SemanticReranker(embeddings_model)
    elif reranker_type.lower() == 'hybrid':
        if embeddings_model is None:
            raise ValueError("HybridReranker requires embeddings_model")
        return HybridReranker(embeddings_model, **kwargs)
    elif reranker_type.lower() == 'diversity':
        if embeddings_model is None:
            raise ValueError("DiversityReranker requires embeddings_model")
        return DiversityReranker(embeddings_model, **kwargs)
    else:
        raise ValueError(f"Unknown reranker type: {reranker_type}")


# Usage example
if __name__ == "__main__":
    # Mock documents
    class MockDoc:
        def __init__(self, content):
            self.page_content = content

    docs = [
        MockDoc("Artificial intelligence is a branch of computer science"),
        MockDoc("Machine learning is a subfield of artificial intelligence"),
        MockDoc("Deep learning uses neural networks"),
        MockDoc("Natural language processing works with text data"),
        MockDoc("The weather is nice today")
    ]

    query = "What is artificial intelligence?"

    # Test TF-IDF reranking
    tfidf_reranker = TFIDFReranker()
    results = tfidf_reranker.rerank(query, docs, top_k=3)

    print("TF-IDF reranking results:")
    for doc, score in results:
        print(f"score: {score:.4f} - content: {doc.page_content}")
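A sketch of the factory with the hybrid strategy; the ToyEmbeddings stub is an illustrative stand-in for a real embeddings model (it just hashes text into fixed 8-dimensional vectors so the example is self-contained):

import hashlib

class ToyEmbeddings:
    """Deterministic stand-in for a real embeddings model (illustration only)."""
    def embed_query(self, text):
        return self.embed_documents([text])[0]
    def embed_documents(self, texts):
        # Hash each text into a fixed 8-dimensional pseudo-embedding.
        return [[b / 255 for b in hashlib.md5(t.encode()).digest()[:8]] for t in texts]

reranker = create_reranker('hybrid', embeddings_model=ToyEmbeddings(),
                           weights={'tfidf': 0.2, 'bm25': 0.3, 'semantic': 0.5})
ranked = reranker.rerank("what is machine learning", docs, top_k=3)  # docs as in the demo above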
routers_and_graders.py
ADDED
@@ -0,0 +1,147 @@
"""
Routers and graders module
Contains query routing, document relevance grading, answer quality grading, and hallucination detection
"""

from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser

from config import LOCAL_LLM


class QueryRouter:
    """Query router that decides between the vector store and web search"""

    def __init__(self):
        self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
        self.prompt = PromptTemplate(
            template="""You are an expert at routing a user question to a vectorstore or web search.
Use the vectorstore for questions on LLM agents, prompt engineering, and adversarial attacks.
You do not need to be strict with the keywords in the question related to these topics.
Otherwise, use web search. Give a binary choice 'web_search' or 'vectorstore' based on the question.
Return a JSON with a single key 'datasource' and no preamble or explanation.
Question to route: {question}""",
            input_variables=["question"],
        )
        self.router = self.prompt | self.llm | JsonOutputParser()

    def route(self, question: str) -> str:
        """Route the question to the appropriate data source"""
        result = self.router.invoke({"question": question})
        return result.get("datasource", "web_search")


class DocumentGrader:
    """Document relevance grader"""

    def __init__(self):
        self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
        self.prompt = PromptTemplate(
            template="""You are a grader assessing whether a retrieved document is relevant to a user question.
If the document contains keywords or semantics related to the user question, grade it as 'yes'.
Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question.
Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.

Retrieved document:

{document}


User question: {question}""",
            input_variables=["question", "document"],
        )
        self.grader = self.prompt | self.llm | JsonOutputParser()

    def grade(self, question: str, document: str) -> str:
        """Grade the relevance of a document to the question"""
        result = self.grader.invoke({"question": question, "document": document})
        return result.get("score", "no")


class AnswerGrader:
    """Answer quality grader"""

    def __init__(self):
        self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
        self.prompt = PromptTemplate(
            template="""You are a grader assessing whether an answer is useful to resolve a question.
Here is the answer:
\n ------- \n
{generation}
\n ------- \n
Here is the question: {question}
Give a binary score 'yes' or 'no' to indicate whether the answer is useful to resolve the question.
Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.""",
            input_variables=["generation", "question"],
        )
        self.grader = self.prompt | self.llm | JsonOutputParser()

    def grade(self, question: str, generation: str) -> str:
        """Grade the quality of an answer"""
        result = self.grader.invoke({"question": question, "generation": generation})
        return result.get("score", "no")


class HallucinationGrader:
    """Hallucination detector"""

    def __init__(self):
        self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
        self.prompt = PromptTemplate(
            template="""You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts.
Give a binary score 'yes' or 'no'. 'yes' means the answer is grounded in / supported by the documents.
Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.

Retrieved documents:

{documents}


LLM generation: {generation}""",
            input_variables=["generation", "documents"],
        )
        self.grader = self.prompt | self.llm | JsonOutputParser()

    def grade(self, generation: str, documents) -> str:
        """Detect whether the generation is hallucinated"""
        result = self.grader.invoke({"generation": generation, "documents": documents})
        return result.get("score", "no")


class QueryRewriter:
    """Query rewriter that optimizes queries for better retrieval results"""

    def __init__(self):
        self.llm = ChatOllama(model=LOCAL_LLM, temperature=0)
        self.prompt = PromptTemplate(
            template="""You are a question rewriter that converts an input question into a better version optimized for vectorstore retrieval.
Look at the initial question and formulate an improved question.
Here is the initial question: \n\n {question}. Improved question with no preamble: \n """,
            input_variables=["question"],
        )
        self.rewriter = self.prompt | self.llm | StrOutputParser()

    def rewrite(self, question: str) -> str:
        """Rewrite the query for better retrieval"""
        print(f"---ORIGINAL QUERY: {question}---")
        rewritten_query = self.rewriter.invoke({"question": question})
        print(f"---REWRITTEN QUERY: {rewritten_query}---")
        return rewritten_query


def initialize_graders_and_router():
    """Initialize all graders and the router"""
    query_router = QueryRouter()
    document_grader = DocumentGrader()
    answer_grader = AnswerGrader()
    hallucination_grader = HallucinationGrader()
    query_rewriter = QueryRewriter()

    return {
        "query_router": query_router,
        "document_grader": document_grader,
        "answer_grader": answer_grader,
        "hallucination_grader": hallucination_grader,
        "query_rewriter": query_rewriter
    }
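A hedged sketch of exercising these components together follows. It is not part of the committed file, and it assumes a local Ollama server is running with the model named by LOCAL_LLM already pulled; the values in the comments are illustrative, not guaranteed outputs.

# Hedged sketch: routing a question and grading a candidate document.
components = initialize_graders_and_router()

question = "What is prompt engineering?"
datasource = components["query_router"].route(question)  # "vectorstore" or "web_search"
relevance = components["document_grader"].grade(
    question,
    "Prompt engineering is the practice of designing inputs for LLMs.",
)  # "yes" or "no"
better_question = components["query_rewriter"].rewrite(question)
print(datasource, relevance, better_question)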
test_reranking.py
ADDED
@@ -0,0 +1,218 @@
#!/usr/bin/env python3
"""
Reranking feature test script
Demonstrates the behavior of the different reranking strategies
"""

import sys
import os
sys.path.append(os.path.dirname(__file__))

from document_processor import DocumentProcessor
from reranker import (
    TFIDFReranker,
    BM25Reranker,
    SemanticReranker,
    HybridReranker,
    DiversityReranker,
)
from langchain.schema import Document
import time


def create_test_documents():
    """Create test documents"""
    return [
        Document(
            page_content="Artificial intelligence (AI) is a branch of computer science devoted to building systems that can perform tasks normally requiring human intelligence.",
            metadata={"source": "ai_intro.txt", "category": "AI basics"}
        ),
        Document(
            page_content="Machine learning is an important subfield of artificial intelligence that uses algorithms to let computers learn patterns and regularities from data.",
            metadata={"source": "ml_basics.txt", "category": "Machine learning"}
        ),
        Document(
            page_content="Deep learning is a branch of machine learning that uses multi-layer neural networks to mimic how the human brain learns.",
            metadata={"source": "dl_guide.txt", "category": "Deep learning"}
        ),
        Document(
            page_content="Natural language processing (NLP) is an important branch of artificial intelligence focused on enabling computers to understand and process human language.",
            metadata={"source": "nlp_overview.txt", "category": "Natural language processing"}
        ),
        Document(
            page_content="Computer vision is another important area of artificial intelligence that enables computers to recognize and understand image and video content.",
            metadata={"source": "cv_intro.txt", "category": "Computer vision"}
        ),
        Document(
            page_content="Reinforcement learning is a type of machine learning that learns optimal behavior policies by interacting with an environment.",
            metadata={"source": "rl_basics.txt", "category": "Reinforcement learning"}
        ),
        Document(
            page_content="The weather today is lovely, sunny and perfect for going out to play and exercise.",
            metadata={"source": "weather.txt", "category": "Weather"}
        ),
        Document(
            page_content="Blockchain is a distributed ledger technology that is decentralized and tamper-resistant.",
            metadata={"source": "blockchain.txt", "category": "Blockchain"}
        )
    ]


def test_reranker_comparison():
    """Compare the behavior of the different rerankers"""
    print("🔍 Reranker comparison test")
    print("=" * 60)

    # Create test data
    query = "What are artificial intelligence and machine learning?"
    documents = create_test_documents()

    # Create a simple embeddings model (for testing)
    try:
        from langchain_community.embeddings import HuggingFaceEmbeddings
        embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2",
            model_kwargs={'device': 'cpu'}
        )
        print("✅ Embeddings model loaded successfully")
    except Exception as e:
        print(f"❌ Failed to load embeddings model: {e}")
        print("Falling back to the basic rerankers for this test")
        embeddings = None

    # Test the different rerankers
    rerankers = []

    # TF-IDF reranker
    rerankers.append(("TF-IDF", TFIDFReranker()))

    # BM25 reranker
    rerankers.append(("BM25", BM25Reranker()))

    if embeddings:
        # Semantic reranker
        rerankers.append(("Semantic similarity", SemanticReranker(embeddings)))

        # Hybrid reranker
        rerankers.append(("Hybrid strategy", HybridReranker(embeddings)))

        # Diversity reranker
        rerankers.append(("Diversity optimization", DiversityReranker(embeddings)))

    # Run the tests
    for name, reranker in rerankers:
        print(f"\n📊 {name} reranking results:")
        print("-" * 40)

        start_time = time.time()
        try:
            results = reranker.rerank(query, documents, top_k=5)
            end_time = time.time()

            print(f"⏱️ Processing time: {(end_time - start_time)*1000:.2f}ms")

            for i, (doc, score) in enumerate(results, 1):
                content = doc.page_content[:80] + "..." if len(doc.page_content) > 80 else doc.page_content
                category = doc.metadata.get('category', 'unknown')
                print(f"{i}. [score: {score:.4f}] [{category}] {content}")

        except Exception as e:
            print(f"❌ Reranking failed: {e}")


def test_reranking_with_embeddings():
    """Test reranking with the embeddings model"""
    print("\n\n🧠 Embeddings-based reranking test")
    print("=" * 60)

    try:
        # Create the document processor
        processor = DocumentProcessor()

        # Create test documents
        test_docs = create_test_documents()

        # Test queries
        queries = [
            "What is the definition of artificial intelligence?",
            "The difference between machine learning and deep learning",
            "Applications of natural language processing",
            "What is the weather like today?"
        ]

        for query in queries:
            print(f"\n🔍 Query: {query}")
            print("-" * 30)

            if processor.reranker:
                # Use the reranking feature
                results = processor.reranker.rerank(query, test_docs, top_k=3)

                for i, (doc, score) in enumerate(results, 1):
                    content = doc.page_content[:60] + "..." if len(doc.page_content) > 60 else doc.page_content
                    category = doc.metadata.get('category', 'unknown')
                    print(f"{i}. [score: {score:.4f}] [{category}] {content}")
            else:
                print("❌ Reranker not initialized")

    except Exception as e:
        print(f"❌ Test failed: {e}")


def test_performance_comparison():
    """Performance comparison test"""
    print("\n\n⚡ Performance comparison test")
    print("=" * 60)

    documents = create_test_documents() * 10  # Increase the number of documents
    query = "Trends in the development of artificial intelligence technology"

    # Test the performance of the different rerankers
    rerankers_config = [
        ("No reranking", None),
        ("TF-IDF", TFIDFReranker()),
        ("BM25", BM25Reranker())
    ]

    for name, reranker in rerankers_config:
        times = []

        # Run several times and take the average
        for _ in range(5):
            start_time = time.time()

            if reranker:
                results = reranker.rerank(query, documents, top_k=5)
            else:
                # Simulate the no-reranking case
                results = documents[:5]

            end_time = time.time()
            times.append((end_time - start_time) * 1000)

        avg_time = sum(times) / len(times)
        print(f"{name}: average processing time {avg_time:.2f}ms (documents: {len(documents)})")


def main():
    """Main test entry point"""
    print("🚀 Comprehensive vector reranking test")
    print("=" * 80)

    try:
        # Basic reranker comparison
        test_reranker_comparison()

        # Embeddings-based reranking test
        test_reranking_with_embeddings()

        # Performance comparison test
        test_performance_comparison()

        print("\n\n✅ All tests completed!")
        print("=" * 80)

    except KeyboardInterrupt:
        print("\n❌ Test interrupted by user")
    except Exception as e:
        print(f"\n❌ An error occurred during testing: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()
workflow_nodes.py
ADDED
@@ -0,0 +1,240 @@
"""
Workflow nodes module
Contains all workflow node functions and state management
"""

from typing import List
from typing_extensions import TypedDict
from langchain.schema import Document
from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import StrOutputParser
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain.prompts import PromptTemplate

from config import LOCAL_LLM, WEB_SEARCH_RESULTS_COUNT


class GraphState(TypedDict):
    """
    Represents the state of the graph

    Attributes:
        question: the question
        generation: the LLM generation
        documents: the list of documents
    """
    question: str
    generation: str
    documents: List[Document]


class WorkflowNodes:
    """Workflow node class containing all node functions"""

    def __init__(self, retriever, graders):
        self.retriever = retriever
        self.graders = graders

        # Set up the RAG chain with a local prompt template
        rag_prompt_template = PromptTemplate(
            template="""You are a question-answering assistant. Use the following retrieved context to answer the question.
If you don't know the answer, just say that you don't know. Use at most three sentences and keep the answer concise.

Question: {question}

Context: {context}

Answer:""",
            input_variables=["question", "context"]
        )
        llm = ChatOllama(model=LOCAL_LLM, temperature=0)
        self.rag_chain = rag_prompt_template | llm | StrOutputParser()

        # Set up web search
        self.web_search_tool = TavilySearchResults(k=WEB_SEARCH_RESULTS_COUNT)

    def retrieve(self, state):
        """
        Retrieve documents

        Args:
            state (dict): the current graph state

        Returns:
            state (dict): new state with a documents key containing the retrieved documents
        """
        print("---RETRIEVE---")
        question = state["question"]

        # Retrieval
        documents = self.retriever.get_relevant_documents(question)
        return {"documents": documents, "question": question}

    def generate(self, state):
        """
        Generate an answer

        Args:
            state (dict): the current graph state

        Returns:
            state (dict): new state with a generation key containing the LLM generation
        """
        print("---GENERATE---")
        question = state["question"]
        documents = state["documents"]

        # RAG generation
        generation = self.rag_chain.invoke({"context": documents, "question": question})
        return {"documents": documents, "question": question, "generation": generation}

    def grade_documents(self, state):
        """
        Determine whether the retrieved documents are relevant to the question

        Args:
            state (dict): the current graph state

        Returns:
            state (dict): updated documents key containing only the filtered relevant documents
        """
        print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
        question = state["question"]
        documents = state["documents"]

        # Grade each document
        filtered_docs = []
        for d in documents:
            score = self.graders["document_grader"].grade(question, d.page_content)
            if score == "yes":
                print("---GRADE: DOCUMENT RELEVANT---")
                filtered_docs.append(d)
            else:
                print("---GRADE: DOCUMENT NOT RELEVANT---")
                continue
        return {"documents": filtered_docs, "question": question}

    def transform_query(self, state):
        """
        Transform the query to produce a better question

        Args:
            state (dict): the current graph state

        Returns:
            state (dict): updated question key with the rephrased question
        """
        print("---TRANSFORM QUERY---")
        question = state["question"]
        documents = state["documents"]

        # Rewrite the question
        better_question = self.graders["query_rewriter"].rewrite(question)
        return {"documents": documents, "question": better_question}

    def web_search(self, state):
        """
        Web search based on the rephrased question

        Args:
            state (dict): the current graph state

        Returns:
            state (dict): updated documents key with the web results
        """
        print("---WEB SEARCH---")
        question = state["question"]

        # Web search
        docs = self.web_search_tool.invoke({"query": question})
        web_results = "\n".join([d["content"] for d in docs])
        web_results = Document(page_content=web_results)

        # Wrap in a list so downstream nodes always receive a list of documents
        return {"documents": [web_results], "question": question}

    def route_question(self, state):
        """
        Route the question to web search or RAG

        Args:
            state (dict): the current graph state

        Returns:
            str: the next node to call
        """
        print("---ROUTE QUESTION---")
        question = state["question"]
        print(question)
        source = self.graders["query_router"].route(question)
        print(source)
        if source == "web_search":
            print("---ROUTE QUESTION TO WEB SEARCH---")
            return "web_search"
        elif source == "vectorstore":
            print("---ROUTE QUESTION TO RAG---")
            return "vectorstore"

    def decide_to_generate(self, state):
        """
        Determine whether to generate an answer or regenerate the question

        Args:
            state (dict): the current graph state

        Returns:
            str: binary decision for the next node to call
        """
        print("---ASSESS GRADED DOCUMENTS---")
        filtered_documents = state["documents"]

        if not filtered_documents:
            # All documents were filtered out,
            # so we will regenerate a new query
            print("---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO THE QUESTION, TRANSFORM QUERY---")
            return "transform_query"
        else:
            # We have relevant documents, so generate an answer
            print("---DECISION: GENERATE---")
            return "generate"

    def grade_generation_v_documents_and_question(self, state):
        """
        Determine whether the generation is grounded in the documents and answers the question

        Args:
            state (dict): the current graph state

        Returns:
            str: decision for the next node to call
        """
        print("---CHECK HALLUCINATIONS---")
        question = state["question"]
        documents = state["documents"]
        generation = state["generation"]

        score = self.graders["hallucination_grader"].grade(generation, documents)

        # Check for hallucinations
        if score == "yes":
            print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
            # Check question answering
            print("---GRADE GENERATION vs QUESTION---")
            score = self.graders["answer_grader"].grade(question, generation)
            if score == "yes":
                print("---DECISION: GENERATION ADDRESSES QUESTION---")
                return "useful"
            else:
                print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
                return "not useful"
        else:
            print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RETRY---")
            return "not supported"


def format_docs(docs):
    """Format documents for display"""
    return "\n\n".join(doc.page_content for doc in docs)
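The graph assembly itself lives in main.py, so the wiring below is only a hedged sketch of how these nodes and decision functions could be assembled with LangGraph; the node names and edge labels are assumptions chosen to match the string return values above, and `retriever` and `graders` are assumed to have been built elsewhere in the repo.

# Hedged wiring sketch (not the committed main.py).
# Assumed to exist already:
#   retriever — a LangChain retriever over the vector store
#   graders   — the dict returned by initialize_graders_and_router()
from langgraph.graph import StateGraph, END

nodes = WorkflowNodes(retriever, graders)
workflow = StateGraph(GraphState)

workflow.add_node("web_search", nodes.web_search)
workflow.add_node("retrieve", nodes.retrieve)
workflow.add_node("grade_documents", nodes.grade_documents)
workflow.add_node("generate", nodes.generate)
workflow.add_node("transform_query", nodes.transform_query)

workflow.set_conditional_entry_point(
    nodes.route_question,
    {"web_search": "web_search", "vectorstore": "retrieve"},
)
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    nodes.decide_to_generate,
    {"transform_query": "transform_query", "generate": "generate"},
)
workflow.add_edge("transform_query", "retrieve")
workflow.add_edge("web_search", "generate")
workflow.add_conditional_edges(
    "generate",
    nodes.grade_generation_v_documents_and_question,
    {"useful": END, "not useful": "transform_query", "not supported": "generate"},
)

app = workflow.compile()
# result = app.invoke({"question": "What is prompt engineering?"})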