Spaces:
Running
Running
test
#1
by
goodmodeler
- opened
- .gitattributes +35 -0
- .gitignore +0 -49
- PROJECT_INFO.md +0 -141
- README.md +9 -101
- calibration/__init__.py +0 -5
- calibration/calibration_head.py +0 -210
- calibration/features.py +0 -173
- calibration/trainer.py +0 -171
- config.yaml +0 -189
- data_processing/__init__.py +0 -4
- data_processing/data_loader.py +0 -29
- data_processing/preprocessor.py +0 -112
- eval/__init__.py +0 -6
- eval/eval_attr.py +0 -275
- eval/eval_calib.py +0 -269
- eval/eval_qa.py +0 -137
- eval/eval_system.py +0 -297
- exp_pipeline/pipeline.py +0 -56
- generator/__init__.py +0 -5
- generator/prompt_templates.py +0 -113
- generator/safe_generate.py +0 -170
- generator/vllm_server.py +0 -102
- real_embedding_test.py +0 -269
- requirements.txt +0 -19
- retriever/__init__.py +0 -6
- retriever/embedder.py +0 -49
- retriever/faiss_index.py +0 -131
- retriever/reranker.py +0 -46
- retriever/retriever.py +0 -104
- simple_e2e_test.py +0 -518
- simple_test.py +0 -167
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
DELETED
|
@@ -1,49 +0,0 @@
|
|
| 1 |
-
# Python
|
| 2 |
-
__pycache__/
|
| 3 |
-
*.py[cod]
|
| 4 |
-
*$py.class
|
| 5 |
-
*.so
|
| 6 |
-
.Python
|
| 7 |
-
build/
|
| 8 |
-
develop-eggs/
|
| 9 |
-
dist/
|
| 10 |
-
downloads/
|
| 11 |
-
eggs/
|
| 12 |
-
.eggs/
|
| 13 |
-
lib/
|
| 14 |
-
lib64/
|
| 15 |
-
parts/
|
| 16 |
-
sdist/
|
| 17 |
-
var/
|
| 18 |
-
wheels/
|
| 19 |
-
*.egg-info/
|
| 20 |
-
.installed.cfg
|
| 21 |
-
*.egg
|
| 22 |
-
|
| 23 |
-
# Virtual environments
|
| 24 |
-
venv/
|
| 25 |
-
env/
|
| 26 |
-
ENV/
|
| 27 |
-
|
| 28 |
-
# IDE
|
| 29 |
-
.vscode/
|
| 30 |
-
.idea/
|
| 31 |
-
*.swp
|
| 32 |
-
*.swo
|
| 33 |
-
|
| 34 |
-
# OS
|
| 35 |
-
.DS_Store
|
| 36 |
-
Thumbs.db
|
| 37 |
-
|
| 38 |
-
# Project specific
|
| 39 |
-
cache/
|
| 40 |
-
logs/
|
| 41 |
-
results/
|
| 42 |
-
models/
|
| 43 |
-
index/
|
| 44 |
-
data/
|
| 45 |
-
*.log
|
| 46 |
-
|
| 47 |
-
# Temporary files
|
| 48 |
-
*.tmp
|
| 49 |
-
*.temp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PROJECT_INFO.md
DELETED
|
@@ -1,141 +0,0 @@
|
|
| 1 |
-
# SafeRAG 项目信息
|
| 2 |
-
|
| 3 |
-
## 📁 项目结构
|
| 4 |
-
|
| 5 |
-
```
|
| 6 |
-
safe_rag/
|
| 7 |
-
├── app.py # Gradio 演示应用
|
| 8 |
-
├── requirements.txt # Python 依赖
|
| 9 |
-
├── config.yaml # 配置文件
|
| 10 |
-
├── README.md # 项目说明(HF Spaces 配置)
|
| 11 |
-
├── simple_e2e_test.py # 端到端测试
|
| 12 |
-
├── simple_test.py # 基本功能测试
|
| 13 |
-
├── data_processing/ # 数据处理模块
|
| 14 |
-
│ ├── __init__.py
|
| 15 |
-
│ ├── data_loader.py # 数据加载器
|
| 16 |
-
│ └── preprocessor.py # 文本预处理器
|
| 17 |
-
├── retriever/ # 检索模块
|
| 18 |
-
│ ├── __init__.py
|
| 19 |
-
│ ├── embedder.py # 嵌入生成器
|
| 20 |
-
│ ├── faiss_index.py # FAISS 索引
|
| 21 |
-
│ ├── retriever.py # 检索器
|
| 22 |
-
│ └── reranker.py # 重排序器
|
| 23 |
-
├── generator/ # 生成模块
|
| 24 |
-
│ ├── __init__.py
|
| 25 |
-
│ ├── vllm_server.py # vLLM 服务器
|
| 26 |
-
│ ├── prompt_templates.py # 提示模板
|
| 27 |
-
│ └── safe_generate.py # 安全生成器
|
| 28 |
-
├── calibration/ # 校准模块
|
| 29 |
-
│ ├── __init__.py
|
| 30 |
-
│ ├── features.py # 特征提取
|
| 31 |
-
│ ├── calibration_head.py # 校准头
|
| 32 |
-
│ └── trainer.py # 训练器
|
| 33 |
-
└── eval/ # 评估模块
|
| 34 |
-
├── __init__.py
|
| 35 |
-
├── eval_qa.py # QA 评估
|
| 36 |
-
├── eval_attr.py # 归因评估
|
| 37 |
-
├── eval_calib.py # 校准评估
|
| 38 |
-
└── eval_system.py # 系统评估
|
| 39 |
-
```
|
| 40 |
-
|
| 41 |
-
## 🚀 核心功能
|
| 42 |
-
|
| 43 |
-
### 1. 数据处理 (`data_processing/`)
|
| 44 |
-
- **DataLoader**: 加载 HF Datasets(HotpotQA, TriviaQA, Wikipedia)
|
| 45 |
-
- **Preprocessor**: 文本清理、句子分割、词元化
|
| 46 |
-
|
| 47 |
-
### 2. 检索系统 (`retriever/`)
|
| 48 |
-
- **Embedder**: 使用 BGE/E5 生成嵌入向量
|
| 49 |
-
- **FAISSIndex**: 构建和搜索 FAISS 索引
|
| 50 |
-
- **Retriever**: 批量检索相关文档
|
| 51 |
-
- **Reranker**: 重排序提升检索质量
|
| 52 |
-
|
| 53 |
-
### 3. 生成系统 (`generator/`)
|
| 54 |
-
- **VLLMServer**: vLLM 推理服务器
|
| 55 |
-
- **SafeGenerator**: 风险感知的答案生成
|
| 56 |
-
- **PromptTemplates**: 提示模板管理
|
| 57 |
-
|
| 58 |
-
### 4. 风险校准 (`calibration/`)
|
| 59 |
-
- **RiskFeatureExtractor**: 提取 16 维风险特征
|
| 60 |
-
- **CalibrationHead**: LogReg/MLP 校准头
|
| 61 |
-
- **Trainer**: 校准头训练
|
| 62 |
-
|
| 63 |
-
### 5. 评估系统 (`eval/`)
|
| 64 |
-
- **QAEvaluator**: EM/F1 评估
|
| 65 |
-
- **AttributionEvaluator**: 引用归因评估
|
| 66 |
-
- **CalibrationEvaluator**: 校准质量评估
|
| 67 |
-
- **SystemEvaluator**: 系统性能评估
|
| 68 |
-
|
| 69 |
-
## 🎯 风险校准策略
|
| 70 |
-
|
| 71 |
-
### 风险特征 (16维)
|
| 72 |
-
1. **检索统计**: 相似度分数、方差、多样性
|
| 73 |
-
2. **覆盖特征**: Q&A 间的 token/实体重叠
|
| 74 |
-
3. **一致性特征**: 段落间语义相似度
|
| 75 |
-
4. **多样性特征**: 主题方差、段落多样性
|
| 76 |
-
|
| 77 |
-
### 自适应策略
|
| 78 |
-
- **低风险 (r < 0.3)**: 正常生成
|
| 79 |
-
- **中风险 (0.3 ≤ r < 0.7)**: 保守生成 + 强制引用
|
| 80 |
-
- **高风险 (r ≥ 0.7)**: 非常保守或拒绝回答
|
| 81 |
-
|
| 82 |
-
## 📊 性能目标
|
| 83 |
-
|
| 84 |
-
- **QA 准确率**: 相比 vanilla RAG 的 EM/F1 提升
|
| 85 |
-
- **归因质量**: 引用精确率/召回率提升 8-12pt
|
| 86 |
-
- **校准质量**: ECE 降低 30-40%
|
| 87 |
-
- **系统吞吐**: vLLM 带来 2-3.5x 提升
|
| 88 |
-
|
| 89 |
-
## 🧪 测试验证
|
| 90 |
-
|
| 91 |
-
### 端到端测试 (`simple_e2e_test.py`)
|
| 92 |
-
- ✅ 8/8 测试通过
|
| 93 |
-
- ✅ 完整 RAG 流程验证
|
| 94 |
-
- ✅ 所有核心功能正常
|
| 95 |
-
|
| 96 |
-
### 基本测试 (`simple_test.py`)
|
| 97 |
-
- ✅ 模块导入测试
|
| 98 |
-
- ✅ 基本功能验证
|
| 99 |
-
- ✅ 配置检查
|
| 100 |
-
|
| 101 |
-
## 🚀 部署到 Hugging Face Spaces
|
| 102 |
-
|
| 103 |
-
### 1. 上传文件
|
| 104 |
-
- 将整个 `safe_rag` 目录上传到 HF Spaces
|
| 105 |
-
- 确保 `app.py` 在根目录
|
| 106 |
-
|
| 107 |
-
### 2. 配置 Spaces
|
| 108 |
-
- SDK: Gradio
|
| 109 |
-
- Hardware: GPU (推荐 A10G 或 A100)
|
| 110 |
-
- Environment: Python 3.8+
|
| 111 |
-
|
| 112 |
-
### 3. 自动部署
|
| 113 |
-
- HF Spaces 会自动安装依赖
|
| 114 |
-
- 自动启动 `app.py`
|
| 115 |
-
- 提供公共访问链接
|
| 116 |
-
|
| 117 |
-
## 📝 使用说明
|
| 118 |
-
|
| 119 |
-
### 本地运行
|
| 120 |
-
```bash
|
| 121 |
-
# 安装依赖
|
| 122 |
-
pip install -r requirements.txt
|
| 123 |
-
|
| 124 |
-
# 运行测试
|
| 125 |
-
python3 simple_e2e_test.py
|
| 126 |
-
|
| 127 |
-
# 启动演示
|
| 128 |
-
python3 app.py
|
| 129 |
-
```
|
| 130 |
-
|
| 131 |
-
### 在线演示
|
| 132 |
-
访问 Hugging Face Spaces 链接,体验交互式 RAG 系统。
|
| 133 |
-
|
| 134 |
-
## 🎉 项目状态
|
| 135 |
-
|
| 136 |
-
✅ **完成**: 所有核心模块实现
|
| 137 |
-
✅ **测试**: 端到端测试通过
|
| 138 |
-
✅ **简化**: 移除不必要的文件
|
| 139 |
-
✅ **就绪**: 可部署到 HF Spaces
|
| 140 |
-
|
| 141 |
-
SafeRAG 项目已经准备好部署和使用了!
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
|
@@ -1,108 +1,16 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
| 11 |
---
|
| 12 |
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
A production-ready Retrieval-Augmented Generation (RAG) system with risk calibration, built on the Hugging Face ecosystem.
|
| 16 |
-
|
| 17 |
-
## 🚀 Key Features
|
| 18 |
-
|
| 19 |
-
- **Risk Calibration**: Multi-layer risk assessment with adaptive strategies
|
| 20 |
-
- **High Performance**: Optimized for 2-3.5x throughput improvement
|
| 21 |
-
- **Hugging Face Native**: Built on HF Datasets, Models, and Spaces
|
| 22 |
-
- **Production Ready**: Complete pipeline with error handling and monitoring
|
| 23 |
-
|
| 24 |
-
## 🏗️ Architecture
|
| 25 |
-
|
| 26 |
-
```
|
| 27 |
-
HF Datasets → Embedding (BGE/E5) → FAISS Index
|
| 28 |
-
Query → Batched Retrieval → Evidence Selector → Generator (vLLM + gpt-oss-20b)
|
| 29 |
-
→ Risk Calibration → Adaptive Strategy → Output (Answer + Citations + Risk Score)
|
| 30 |
-
```
|
| 31 |
-
|
| 32 |
-
## 📊 Performance Targets
|
| 33 |
-
|
| 34 |
-
- **QA Accuracy**: EM/F1 improvements over vanilla RAG
|
| 35 |
-
- **Attribution**: +8-12pt improvement in citation precision/recall
|
| 36 |
-
- **Calibration**: 30-40% reduction in ECE (Expected Calibration Error)
|
| 37 |
-
- **Throughput**: 2-3.5x improvement with vLLM
|
| 38 |
-
|
| 39 |
-
## 🛠️ Quick Start
|
| 40 |
-
|
| 41 |
-
### Run Tests
|
| 42 |
-
```bash
|
| 43 |
-
python3 simple_e2e_test.py
|
| 44 |
-
```
|
| 45 |
-
|
| 46 |
-
### Start Demo
|
| 47 |
-
```bash
|
| 48 |
-
python3 app.py
|
| 49 |
-
```
|
| 50 |
-
|
| 51 |
-
## 📈 Evaluation
|
| 52 |
-
|
| 53 |
-
The system has been tested with comprehensive end-to-end tests:
|
| 54 |
-
|
| 55 |
-
- ✅ Text processing and sentence extraction
|
| 56 |
-
- ✅ Embedding creation and similarity calculation
|
| 57 |
-
- ✅ Passage retrieval and reranking
|
| 58 |
-
- ✅ Risk feature extraction and prediction
|
| 59 |
-
- ✅ Risk-aware answer generation
|
| 60 |
-
- ✅ Evaluation metrics (EM, F1, ROUGE)
|
| 61 |
-
- ✅ Complete end-to-end RAG pipeline
|
| 62 |
-
|
| 63 |
-
## 🔧 Configuration
|
| 64 |
-
|
| 65 |
-
Key parameters in `config.yaml`:
|
| 66 |
-
|
| 67 |
-
- **Risk Thresholds**: τ₁ = 0.3, τ₂ = 0.7
|
| 68 |
-
- **Retrieval**: k = 20, rerank_k = 10
|
| 69 |
-
- **Generation**: max_tokens = 512, temperature = 0.7
|
| 70 |
-
- **Calibration**: 16 features, logistic regression
|
| 71 |
-
|
| 72 |
-
## 🎯 Risk Calibration
|
| 73 |
-
|
| 74 |
-
### Risk Features (16-dimensional)
|
| 75 |
-
1. **Retrieval Statistics**: Similarity scores, variance, diversity
|
| 76 |
-
2. **Coverage Features**: Token/entity overlap between Q&A
|
| 77 |
-
3. **Consistency Features**: Semantic similarity between passages
|
| 78 |
-
4. **Diversity Features**: Topic variance, passage diversity
|
| 79 |
-
|
| 80 |
-
### Adaptive Strategies
|
| 81 |
-
- **Low Risk (r < τ₁)**: Normal generation
|
| 82 |
-
- **Medium Risk (τ₁ ≤ r < τ₂)**: Conservative generation + citations
|
| 83 |
-
- **High Risk (r ≥ τ₂)**: Very conservative or refuse
|
| 84 |
-
|
| 85 |
-
## 📚 Datasets
|
| 86 |
-
|
| 87 |
-
- **HotpotQA**: Multi-hop reasoning with supporting facts
|
| 88 |
-
- **TriviaQA**: Open-domain QA for general knowledge
|
| 89 |
-
- **Wikipedia**: Knowledge base via HF Datasets
|
| 90 |
-
|
| 91 |
-
## 📄 Citation
|
| 92 |
-
|
| 93 |
-
```bibtex
|
| 94 |
-
@article{safrag2024,
|
| 95 |
-
title={SafeRAG: High-Performance Calibrated RAG with Risk Assessment},
|
| 96 |
-
author={Your Name},
|
| 97 |
-
journal={arXiv preprint},
|
| 98 |
-
year={2024}
|
| 99 |
-
}
|
| 100 |
-
```
|
| 101 |
-
|
| 102 |
-
## 📝 License
|
| 103 |
-
|
| 104 |
-
Apache 2.0 License - see LICENSE file for details.
|
| 105 |
-
|
| 106 |
-
---
|
| 107 |
-
|
| 108 |
-
**SafeRAG**: A production-ready RAG system with risk calibration, built on Hugging Face ecosystem.
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Safe Rag
|
| 3 |
+
emoji: 💬
|
| 4 |
+
colorFrom: yellow
|
| 5 |
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 5.42.0
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
hf_oauth: true
|
| 11 |
+
hf_oauth_scopes:
|
| 12 |
+
- inference-api
|
| 13 |
+
short_description: A High-Performance and Risk-Calibrated RAG system
|
| 14 |
---
|
| 15 |
|
| 16 |
+
An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
calibration/__init__.py
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
from .features import RiskFeatureExtractor
|
| 2 |
-
from .calibration_head import CalibrationHead
|
| 3 |
-
from .trainer import CalibrationTrainer
|
| 4 |
-
|
| 5 |
-
__all__ = ['RiskFeatureExtractor', 'CalibrationHead', 'CalibrationTrainer']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
calibration/calibration_head.py
DELETED
|
@@ -1,210 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torch.nn as nn
|
| 3 |
-
import numpy as np
|
| 4 |
-
from sklearn.linear_model import LogisticRegression
|
| 5 |
-
from sklearn.ensemble import RandomForestClassifier
|
| 6 |
-
from sklearn.metrics import accuracy_score, roc_auc_score
|
| 7 |
-
from typing import Dict, Any, List, Tuple
|
| 8 |
-
import logging
|
| 9 |
-
import joblib
|
| 10 |
-
import os
|
| 11 |
-
|
| 12 |
-
logger = logging.getLogger(__name__)
|
| 13 |
-
|
| 14 |
-
class CalibrationHead:
|
| 15 |
-
def __init__(self, model_type: str = "logistic", input_dim: int = 16):
|
| 16 |
-
self.model_type = model_type
|
| 17 |
-
self.input_dim = input_dim
|
| 18 |
-
self.model = None
|
| 19 |
-
self.is_trained = False
|
| 20 |
-
|
| 21 |
-
def _create_model(self):
|
| 22 |
-
"""Create the calibration model"""
|
| 23 |
-
if self.model_type == "logistic":
|
| 24 |
-
self.model = LogisticRegression(
|
| 25 |
-
random_state=42,
|
| 26 |
-
max_iter=1000,
|
| 27 |
-
class_weight='balanced'
|
| 28 |
-
)
|
| 29 |
-
elif self.model_type == "random_forest":
|
| 30 |
-
self.model = RandomForestClassifier(
|
| 31 |
-
n_estimators=100,
|
| 32 |
-
random_state=42,
|
| 33 |
-
class_weight='balanced'
|
| 34 |
-
)
|
| 35 |
-
elif self.model_type == "mlp":
|
| 36 |
-
self.model = MLPCalibrationHead(self.input_dim)
|
| 37 |
-
else:
|
| 38 |
-
raise ValueError(f"Unknown model type: {self.model_type}")
|
| 39 |
-
|
| 40 |
-
def train(self, X: np.ndarray, y: np.ndarray) -> Dict[str, float]:
|
| 41 |
-
"""Train the calibration model"""
|
| 42 |
-
if self.model is None:
|
| 43 |
-
self._create_model()
|
| 44 |
-
|
| 45 |
-
if self.model_type in ["logistic", "random_forest"]:
|
| 46 |
-
# Sklearn models
|
| 47 |
-
self.model.fit(X, y)
|
| 48 |
-
|
| 49 |
-
# Get predictions and metrics
|
| 50 |
-
y_pred = self.model.predict(X)
|
| 51 |
-
y_proba = self.model.predict_proba(X)[:, 1] if hasattr(self.model, 'predict_proba') else y_pred
|
| 52 |
-
|
| 53 |
-
metrics = {
|
| 54 |
-
'accuracy': accuracy_score(y, y_pred),
|
| 55 |
-
'auc': roc_auc_score(y, y_proba) if len(np.unique(y)) > 1 else 0.0
|
| 56 |
-
}
|
| 57 |
-
else:
|
| 58 |
-
# PyTorch models
|
| 59 |
-
metrics = self._train_pytorch_model(X, y)
|
| 60 |
-
|
| 61 |
-
self.is_trained = True
|
| 62 |
-
logger.info(f"Trained {self.model_type} model with metrics: {metrics}")
|
| 63 |
-
return metrics
|
| 64 |
-
|
| 65 |
-
def predict_risk(self, features: Dict[str, Any]) -> float:
|
| 66 |
-
"""Predict risk score from features"""
|
| 67 |
-
if not self.is_trained:
|
| 68 |
-
logger.warning("Model not trained, returning default risk score")
|
| 69 |
-
return 0.5
|
| 70 |
-
|
| 71 |
-
# Convert features to array
|
| 72 |
-
X = self._features_to_array(features)
|
| 73 |
-
|
| 74 |
-
if self.model_type in ["logistic", "random_forest"]:
|
| 75 |
-
if hasattr(self.model, 'predict_proba'):
|
| 76 |
-
risk_score = self.model.predict_proba(X.reshape(1, -1))[0, 1]
|
| 77 |
-
else:
|
| 78 |
-
risk_score = float(self.model.predict(X.reshape(1, -1))[0])
|
| 79 |
-
else:
|
| 80 |
-
# PyTorch models
|
| 81 |
-
with torch.no_grad():
|
| 82 |
-
X_tensor = torch.FloatTensor(X.reshape(1, -1))
|
| 83 |
-
risk_score = torch.sigmoid(self.model(X_tensor)).item()
|
| 84 |
-
|
| 85 |
-
return float(risk_score)
|
| 86 |
-
|
| 87 |
-
def predict_batch(self, features_list: List[Dict[str, Any]]) -> List[float]:
|
| 88 |
-
"""Predict risk scores for multiple feature sets"""
|
| 89 |
-
if not features_list:
|
| 90 |
-
return []
|
| 91 |
-
|
| 92 |
-
# Convert all features to arrays
|
| 93 |
-
X = np.array([self._features_to_array(f) for f in features_list])
|
| 94 |
-
|
| 95 |
-
if self.model_type in ["logistic", "random_forest"]:
|
| 96 |
-
if hasattr(self.model, 'predict_proba'):
|
| 97 |
-
risk_scores = self.model.predict_proba(X)[:, 1]
|
| 98 |
-
else:
|
| 99 |
-
risk_scores = self.model.predict(X)
|
| 100 |
-
else:
|
| 101 |
-
# PyTorch models
|
| 102 |
-
with torch.no_grad():
|
| 103 |
-
X_tensor = torch.FloatTensor(X)
|
| 104 |
-
risk_scores = torch.sigmoid(self.model(X_tensor)).numpy()
|
| 105 |
-
|
| 106 |
-
return risk_scores.tolist()
|
| 107 |
-
|
| 108 |
-
def _features_to_array(self, features: Dict[str, Any]) -> np.ndarray:
|
| 109 |
-
"""Convert features dictionary to numpy array"""
|
| 110 |
-
# Define feature order (must match training)
|
| 111 |
-
feature_order = [
|
| 112 |
-
'num_passages', 'avg_similarity', 'std_similarity', 'max_similarity',
|
| 113 |
-
'min_similarity', 'score_variance', 'avg_token_overlap', 'max_token_overlap',
|
| 114 |
-
'avg_entity_overlap', 'max_entity_overlap', 'passage_consistency',
|
| 115 |
-
'passage_consistency_std', 'min_passage_similarity', 'diversity',
|
| 116 |
-
'topic_variance'
|
| 117 |
-
]
|
| 118 |
-
|
| 119 |
-
# Extract features in order
|
| 120 |
-
feature_array = []
|
| 121 |
-
for feature_name in feature_order:
|
| 122 |
-
value = features.get(feature_name, 0.0)
|
| 123 |
-
feature_array.append(float(value))
|
| 124 |
-
|
| 125 |
-
return np.array(feature_array)
|
| 126 |
-
|
| 127 |
-
def _train_pytorch_model(self, X: np.ndarray, y: np.ndarray) -> Dict[str, float]:
|
| 128 |
-
"""Train PyTorch model"""
|
| 129 |
-
# Convert to tensors
|
| 130 |
-
X_tensor = torch.FloatTensor(X)
|
| 131 |
-
y_tensor = torch.FloatTensor(y)
|
| 132 |
-
|
| 133 |
-
# Training setup
|
| 134 |
-
optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
|
| 135 |
-
criterion = nn.BCEWithLogitsLoss()
|
| 136 |
-
|
| 137 |
-
# Training loop
|
| 138 |
-
self.model.train()
|
| 139 |
-
for epoch in range(100):
|
| 140 |
-
optimizer.zero_grad()
|
| 141 |
-
outputs = self.model(X_tensor)
|
| 142 |
-
loss = criterion(outputs.squeeze(), y_tensor)
|
| 143 |
-
loss.backward()
|
| 144 |
-
optimizer.step()
|
| 145 |
-
|
| 146 |
-
# Evaluation
|
| 147 |
-
self.model.eval()
|
| 148 |
-
with torch.no_grad():
|
| 149 |
-
outputs = self.model(X_tensor)
|
| 150 |
-
predictions = torch.sigmoid(outputs).squeeze().numpy()
|
| 151 |
-
binary_preds = (predictions > 0.5).astype(int)
|
| 152 |
-
|
| 153 |
-
metrics = {
|
| 154 |
-
'accuracy': accuracy_score(y, binary_preds),
|
| 155 |
-
'auc': roc_auc_score(y, predictions) if len(np.unique(y)) > 1 else 0.0
|
| 156 |
-
}
|
| 157 |
-
|
| 158 |
-
return metrics
|
| 159 |
-
|
| 160 |
-
def save(self, path: str) -> None:
|
| 161 |
-
"""Save the trained model"""
|
| 162 |
-
os.makedirs(os.path.dirname(path), exist_ok=True)
|
| 163 |
-
|
| 164 |
-
if self.model_type in ["logistic", "random_forest"]:
|
| 165 |
-
joblib.dump(self.model, f"{path}.joblib")
|
| 166 |
-
else:
|
| 167 |
-
torch.save(self.model.state_dict(), f"{path}.pth")
|
| 168 |
-
|
| 169 |
-
# Save metadata
|
| 170 |
-
metadata = {
|
| 171 |
-
'model_type': self.model_type,
|
| 172 |
-
'input_dim': self.input_dim,
|
| 173 |
-
'is_trained': self.is_trained
|
| 174 |
-
}
|
| 175 |
-
joblib.dump(metadata, f"{path}_metadata.joblib")
|
| 176 |
-
|
| 177 |
-
logger.info(f"Saved model to {path}")
|
| 178 |
-
|
| 179 |
-
def load(self, path: str) -> None:
|
| 180 |
-
"""Load a trained model"""
|
| 181 |
-
# Load metadata
|
| 182 |
-
metadata = joblib.load(f"{path}_metadata.joblib")
|
| 183 |
-
self.model_type = metadata['model_type']
|
| 184 |
-
self.input_dim = metadata['input_dim']
|
| 185 |
-
self.is_trained = metadata['is_trained']
|
| 186 |
-
|
| 187 |
-
# Load model
|
| 188 |
-
if self.model_type in ["logistic", "random_forest"]:
|
| 189 |
-
self.model = joblib.load(f"{path}.joblib")
|
| 190 |
-
else:
|
| 191 |
-
self.model = MLPCalibrationHead(self.input_dim)
|
| 192 |
-
self.model.load_state_dict(torch.load(f"{path}.pth"))
|
| 193 |
-
|
| 194 |
-
logger.info(f"Loaded model from {path}")
|
| 195 |
-
|
| 196 |
-
class MLPCalibrationHead(nn.Module):
|
| 197 |
-
def __init__(self, input_dim: int, hidden_dim: int = 64):
|
| 198 |
-
super().__init__()
|
| 199 |
-
self.layers = nn.Sequential(
|
| 200 |
-
nn.Linear(input_dim, hidden_dim),
|
| 201 |
-
nn.ReLU(),
|
| 202 |
-
nn.Dropout(0.2),
|
| 203 |
-
nn.Linear(hidden_dim, hidden_dim // 2),
|
| 204 |
-
nn.ReLU(),
|
| 205 |
-
nn.Dropout(0.2),
|
| 206 |
-
nn.Linear(hidden_dim // 2, 1)
|
| 207 |
-
)
|
| 208 |
-
|
| 209 |
-
def forward(self, x):
|
| 210 |
-
return self.layers(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
calibration/features.py
DELETED
|
@@ -1,173 +0,0 @@
|
|
| 1 |
-
from typing import List, Dict, Any
|
| 2 |
-
import numpy as np
|
| 3 |
-
from sentence_transformers import SentenceTransformer
|
| 4 |
-
import logging
|
| 5 |
-
from sklearn.feature_extraction.text import TfidfVectorizer
|
| 6 |
-
from sklearn.metrics.pairwise import cosine_similarity
|
| 7 |
-
import re
|
| 8 |
-
|
| 9 |
-
logger = logging.getLogger(__name__)
|
| 10 |
-
|
| 11 |
-
class RiskFeatureExtractor:
|
| 12 |
-
def __init__(self, embedding_model: str = "BAAI/bge-large-en-v1.5"):
|
| 13 |
-
self.embedding_model = SentenceTransformer(embedding_model)
|
| 14 |
-
self.tfidf_vectorizer = TfidfVectorizer(max_features=1000, stop_words='english')
|
| 15 |
-
|
| 16 |
-
def extract_features(self, question: str, retrieved_passages: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 17 |
-
"""Extract risk assessment features"""
|
| 18 |
-
if not retrieved_passages:
|
| 19 |
-
return self._get_empty_features()
|
| 20 |
-
|
| 21 |
-
features = {}
|
| 22 |
-
|
| 23 |
-
# Retrieval statistics
|
| 24 |
-
features.update(self._extract_retrieval_stats(retrieved_passages))
|
| 25 |
-
|
| 26 |
-
# Coverage features
|
| 27 |
-
features.update(self._extract_coverage_features(question, retrieved_passages))
|
| 28 |
-
|
| 29 |
-
# Consistency features
|
| 30 |
-
features.update(self._extract_consistency_features(question, retrieved_passages))
|
| 31 |
-
|
| 32 |
-
# Diversity features
|
| 33 |
-
features.update(self._extract_diversity_features(retrieved_passages))
|
| 34 |
-
|
| 35 |
-
return features
|
| 36 |
-
|
| 37 |
-
def _extract_retrieval_stats(self, passages: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 38 |
-
"""Extract retrieval statistics"""
|
| 39 |
-
if not passages:
|
| 40 |
-
return {}
|
| 41 |
-
|
| 42 |
-
scores = [p.get('score', 0.0) for p in passages]
|
| 43 |
-
|
| 44 |
-
return {
|
| 45 |
-
'num_passages': len(passages),
|
| 46 |
-
'avg_similarity': np.mean(scores),
|
| 47 |
-
'std_similarity': np.std(scores),
|
| 48 |
-
'max_similarity': np.max(scores),
|
| 49 |
-
'min_similarity': np.min(scores),
|
| 50 |
-
'score_variance': np.var(scores)
|
| 51 |
-
}
|
| 52 |
-
|
| 53 |
-
def _extract_coverage_features(self, question: str, passages: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 54 |
-
"""Extract coverage features between question and passages"""
|
| 55 |
-
if not passages:
|
| 56 |
-
return {}
|
| 57 |
-
|
| 58 |
-
# Token overlap
|
| 59 |
-
question_tokens = set(question.lower().split())
|
| 60 |
-
passage_texts = [p.get('text', '') for p in passages]
|
| 61 |
-
|
| 62 |
-
overlaps = []
|
| 63 |
-
for passage_text in passage_texts:
|
| 64 |
-
passage_tokens = set(passage_text.lower().split())
|
| 65 |
-
overlap = len(question_tokens.intersection(passage_tokens))
|
| 66 |
-
overlaps.append(overlap / len(question_tokens) if question_tokens else 0)
|
| 67 |
-
|
| 68 |
-
# Entity overlap (simplified)
|
| 69 |
-
question_entities = self._extract_entities(question)
|
| 70 |
-
entity_overlaps = []
|
| 71 |
-
|
| 72 |
-
for passage_text in passage_texts:
|
| 73 |
-
passage_entities = self._extract_entities(passage_text)
|
| 74 |
-
overlap = len(question_entities.intersection(passage_entities))
|
| 75 |
-
entity_overlaps.append(overlap / len(question_entities) if question_entities else 0)
|
| 76 |
-
|
| 77 |
-
return {
|
| 78 |
-
'avg_token_overlap': np.mean(overlaps),
|
| 79 |
-
'max_token_overlap': np.max(overlaps),
|
| 80 |
-
'avg_entity_overlap': np.mean(entity_overlaps),
|
| 81 |
-
'max_entity_overlap': np.max(entity_overlaps)
|
| 82 |
-
}
|
| 83 |
-
|
| 84 |
-
def _extract_consistency_features(self, question: str, passages: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 85 |
-
"""Extract consistency features between passages"""
|
| 86 |
-
if len(passages) < 2:
|
| 87 |
-
return {'passage_consistency': 1.0}
|
| 88 |
-
|
| 89 |
-
# Semantic similarity between passages
|
| 90 |
-
passage_texts = [p.get('text', '') for p in passages]
|
| 91 |
-
embeddings = self.embedding_model.encode(passage_texts)
|
| 92 |
-
|
| 93 |
-
# Compute pairwise similarities
|
| 94 |
-
similarities = cosine_similarity(embeddings)
|
| 95 |
-
|
| 96 |
-
# Get upper triangle (excluding diagonal)
|
| 97 |
-
upper_triangle = similarities[np.triu_indices_from(similarities, k=1)]
|
| 98 |
-
|
| 99 |
-
return {
|
| 100 |
-
'passage_consistency': np.mean(upper_triangle),
|
| 101 |
-
'passage_consistency_std': np.std(upper_triangle),
|
| 102 |
-
'min_passage_similarity': np.min(upper_triangle)
|
| 103 |
-
}
|
| 104 |
-
|
| 105 |
-
def _extract_diversity_features(self, passages: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 106 |
-
"""Extract diversity features"""
|
| 107 |
-
if len(passages) < 2:
|
| 108 |
-
return {'diversity': 1.0}
|
| 109 |
-
|
| 110 |
-
# Topic diversity using TF-IDF
|
| 111 |
-
passage_texts = [p.get('text', '') for p in passages]
|
| 112 |
-
|
| 113 |
-
try:
|
| 114 |
-
tfidf_matrix = self.tfidf_vectorizer.fit_transform(passage_texts)
|
| 115 |
-
similarities = cosine_similarity(tfidf_matrix)
|
| 116 |
-
|
| 117 |
-
# Diversity as 1 - average similarity
|
| 118 |
-
upper_triangle = similarities[np.triu_indices_from(similarities, k=1)]
|
| 119 |
-
diversity = 1.0 - np.mean(upper_triangle)
|
| 120 |
-
|
| 121 |
-
return {
|
| 122 |
-
'diversity': diversity,
|
| 123 |
-
'topic_variance': np.var(upper_triangle)
|
| 124 |
-
}
|
| 125 |
-
except:
|
| 126 |
-
return {'diversity': 0.5, 'topic_variance': 0.0}
|
| 127 |
-
|
| 128 |
-
def _extract_entities(self, text: str) -> set:
|
| 129 |
-
"""Extract entities from text (simplified)"""
|
| 130 |
-
# Simple entity extraction - in practice use NER
|
| 131 |
-
# Look for capitalized words and common entity patterns
|
| 132 |
-
entities = set()
|
| 133 |
-
|
| 134 |
-
# Capitalized words (potential entities)
|
| 135 |
-
capitalized = re.findall(r'\b[A-Z][a-z]+\b', text)
|
| 136 |
-
entities.update(capitalized)
|
| 137 |
-
|
| 138 |
-
# Numbers and dates
|
| 139 |
-
numbers = re.findall(r'\b\d+\b', text)
|
| 140 |
-
entities.update(numbers)
|
| 141 |
-
|
| 142 |
-
return entities
|
| 143 |
-
|
| 144 |
-
def _get_empty_features(self) -> Dict[str, Any]:
|
| 145 |
-
"""Return empty features when no passages available"""
|
| 146 |
-
return {
|
| 147 |
-
'num_passages': 0,
|
| 148 |
-
'avg_similarity': 0.0,
|
| 149 |
-
'std_similarity': 0.0,
|
| 150 |
-
'max_similarity': 0.0,
|
| 151 |
-
'min_similarity': 0.0,
|
| 152 |
-
'score_variance': 0.0,
|
| 153 |
-
'avg_token_overlap': 0.0,
|
| 154 |
-
'max_token_overlap': 0.0,
|
| 155 |
-
'avg_entity_overlap': 0.0,
|
| 156 |
-
'max_entity_overlap': 0.0,
|
| 157 |
-
'passage_consistency': 0.0,
|
| 158 |
-
'passage_consistency_std': 0.0,
|
| 159 |
-
'min_passage_similarity': 0.0,
|
| 160 |
-
'diversity': 0.0,
|
| 161 |
-
'topic_variance': 0.0
|
| 162 |
-
}
|
| 163 |
-
|
| 164 |
-
def extract_batch_features(self, questions: List[str],
|
| 165 |
-
passages_list: List[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
|
| 166 |
-
"""Extract features for multiple question-passage pairs"""
|
| 167 |
-
features_list = []
|
| 168 |
-
|
| 169 |
-
for question, passages in zip(questions, passages_list):
|
| 170 |
-
features = self.extract_features(question, passages)
|
| 171 |
-
features_list.append(features)
|
| 172 |
-
|
| 173 |
-
return features_list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
calibration/trainer.py
DELETED
|
@@ -1,171 +0,0 @@
|
|
| 1 |
-
from typing import List, Dict, Any, Tuple
|
| 2 |
-
import numpy as np
|
| 3 |
-
from sklearn.model_selection import train_test_split
|
| 4 |
-
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
|
| 5 |
-
import logging
|
| 6 |
-
from .features import RiskFeatureExtractor
|
| 7 |
-
from .calibration_head import CalibrationHead
|
| 8 |
-
|
| 9 |
-
logger = logging.getLogger(__name__)
|
| 10 |
-
|
| 11 |
-
class CalibrationTrainer:
|
| 12 |
-
def __init__(self, feature_extractor: RiskFeatureExtractor,
|
| 13 |
-
calibration_head: CalibrationHead):
|
| 14 |
-
self.feature_extractor = feature_extractor
|
| 15 |
-
self.calibration_head = calibration_head
|
| 16 |
-
|
| 17 |
-
def prepare_training_data(self, qa_data: List[Dict[str, Any]],
|
| 18 |
-
retrieved_passages_list: List[List[Dict[str, Any]]],
|
| 19 |
-
labels: List[int]) -> Tuple[np.ndarray, np.ndarray]:
|
| 20 |
-
"""Prepare training data from QA samples and retrieved passages"""
|
| 21 |
-
|
| 22 |
-
# Extract features
|
| 23 |
-
features_list = self.feature_extractor.extract_batch_features(
|
| 24 |
-
[item['question'] for item in qa_data],
|
| 25 |
-
retrieved_passages_list
|
| 26 |
-
)
|
| 27 |
-
|
| 28 |
-
# Convert features to arrays
|
| 29 |
-
X = np.array([self.feature_extractor._features_to_array(f) for f in features_list])
|
| 30 |
-
y = np.array(labels)
|
| 31 |
-
|
| 32 |
-
logger.info(f"Prepared training data: {X.shape[0]} samples, {X.shape[1]} features")
|
| 33 |
-
return X, y
|
| 34 |
-
|
| 35 |
-
def train(self, X: np.ndarray, y: np.ndarray,
|
| 36 |
-
test_size: float = 0.2, random_state: int = 42) -> Dict[str, Any]:
|
| 37 |
-
"""Train the calibration model"""
|
| 38 |
-
|
| 39 |
-
# Split data
|
| 40 |
-
X_train, X_test, y_train, y_test = train_test_split(
|
| 41 |
-
X, y, test_size=test_size, random_state=random_state, stratify=y
|
| 42 |
-
)
|
| 43 |
-
|
| 44 |
-
# Train model
|
| 45 |
-
train_metrics = self.calibration_head.train(X_train, y_train)
|
| 46 |
-
|
| 47 |
-
# Evaluate on test set
|
| 48 |
-
test_metrics = self.evaluate(X_test, y_test)
|
| 49 |
-
|
| 50 |
-
# Combine metrics
|
| 51 |
-
all_metrics = {
|
| 52 |
-
'train': train_metrics,
|
| 53 |
-
'test': test_metrics,
|
| 54 |
-
'train_size': len(X_train),
|
| 55 |
-
'test_size': len(X_test)
|
| 56 |
-
}
|
| 57 |
-
|
| 58 |
-
logger.info(f"Training completed. Test metrics: {test_metrics}")
|
| 59 |
-
return all_metrics
|
| 60 |
-
|
| 61 |
-
def evaluate(self, X: np.ndarray, y: np.ndarray) -> Dict[str, float]:
|
| 62 |
-
"""Evaluate the calibration model"""
|
| 63 |
-
if not self.calibration_head.is_trained:
|
| 64 |
-
raise ValueError("Model not trained yet")
|
| 65 |
-
|
| 66 |
-
# Get predictions
|
| 67 |
-
if hasattr(self.calibration_head.model, 'predict_proba'):
|
| 68 |
-
y_proba = self.calibration_head.model.predict_proba(X)[:, 1]
|
| 69 |
-
y_pred = (y_proba > 0.5).astype(int)
|
| 70 |
-
else:
|
| 71 |
-
y_pred = self.calibration_head.model.predict(X)
|
| 72 |
-
y_proba = y_pred
|
| 73 |
-
|
| 74 |
-
# Calculate metrics
|
| 75 |
-
accuracy = accuracy_score(y, y_pred)
|
| 76 |
-
precision, recall, f1, _ = precision_recall_fscore_support(y, y_pred, average='binary')
|
| 77 |
-
|
| 78 |
-
try:
|
| 79 |
-
auc = roc_auc_score(y, y_proba)
|
| 80 |
-
except:
|
| 81 |
-
auc = 0.0
|
| 82 |
-
|
| 83 |
-
return {
|
| 84 |
-
'accuracy': accuracy,
|
| 85 |
-
'precision': precision,
|
| 86 |
-
'recall': recall,
|
| 87 |
-
'f1': f1,
|
| 88 |
-
'auc': auc
|
| 89 |
-
}
|
| 90 |
-
|
| 91 |
-
def create_synthetic_labels(self, qa_data: List[Dict[str, Any]],
|
| 92 |
-
retrieved_passages_list: List[List[Dict[str, Any]]]) -> List[int]:
|
| 93 |
-
"""Create synthetic risk labels for training (placeholder implementation)"""
|
| 94 |
-
labels = []
|
| 95 |
-
|
| 96 |
-
for qa_item, passages in zip(qa_data, retrieved_passages_list):
|
| 97 |
-
# Simple heuristic for risk labeling
|
| 98 |
-
# In practice, this would be based on human annotations or automated evaluation
|
| 99 |
-
|
| 100 |
-
question = qa_item['question']
|
| 101 |
-
answer = qa_item['answer']
|
| 102 |
-
|
| 103 |
-
# Risk factors
|
| 104 |
-
risk_score = 0.0
|
| 105 |
-
|
| 106 |
-
# Low similarity scores = high risk
|
| 107 |
-
if passages:
|
| 108 |
-
avg_similarity = np.mean([p.get('score', 0.0) for p in passages])
|
| 109 |
-
if avg_similarity < 0.3:
|
| 110 |
-
risk_score += 0.3
|
| 111 |
-
|
| 112 |
-
# Few passages = high risk
|
| 113 |
-
if len(passages) < 3:
|
| 114 |
-
risk_score += 0.2
|
| 115 |
-
|
| 116 |
-
# Question complexity (length, question words)
|
| 117 |
-
if len(question.split()) > 20:
|
| 118 |
-
risk_score += 0.1
|
| 119 |
-
|
| 120 |
-
if any(word in question.lower() for word in ['why', 'how', 'explain', 'compare']):
|
| 121 |
-
risk_score += 0.1
|
| 122 |
-
|
| 123 |
-
# Answer length (very short or very long answers might be risky)
|
| 124 |
-
if len(answer.split()) < 5 or len(answer.split()) > 100:
|
| 125 |
-
risk_score += 0.1
|
| 126 |
-
|
| 127 |
-
# Convert to binary label
|
| 128 |
-
label = 1 if risk_score > 0.3 else 0
|
| 129 |
-
labels.append(label)
|
| 130 |
-
|
| 131 |
-
logger.info(f"Created {sum(labels)} high-risk labels out of {len(labels)} total")
|
| 132 |
-
return labels
|
| 133 |
-
|
| 134 |
-
def cross_validate(self, X: np.ndarray, y: np.ndarray,
|
| 135 |
-
cv_folds: int = 5) -> Dict[str, List[float]]:
|
| 136 |
-
"""Perform cross-validation"""
|
| 137 |
-
from sklearn.model_selection import StratifiedKFold
|
| 138 |
-
|
| 139 |
-
skf = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=42)
|
| 140 |
-
|
| 141 |
-
fold_metrics = {
|
| 142 |
-
'accuracy': [],
|
| 143 |
-
'precision': [],
|
| 144 |
-
'recall': [],
|
| 145 |
-
'f1': [],
|
| 146 |
-
'auc': []
|
| 147 |
-
}
|
| 148 |
-
|
| 149 |
-
for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
|
| 150 |
-
logger.info(f"Training fold {fold + 1}/{cv_folds}")
|
| 151 |
-
|
| 152 |
-
X_train, X_val = X[train_idx], X[val_idx]
|
| 153 |
-
y_train, y_val = y[train_idx], y[val_idx]
|
| 154 |
-
|
| 155 |
-
# Train on fold
|
| 156 |
-
self.calibration_head.train(X_train, y_train)
|
| 157 |
-
|
| 158 |
-
# Evaluate on validation set
|
| 159 |
-
val_metrics = self.evaluate(X_val, y_val)
|
| 160 |
-
|
| 161 |
-
for metric, value in val_metrics.items():
|
| 162 |
-
fold_metrics[metric].append(value)
|
| 163 |
-
|
| 164 |
-
# Calculate mean and std
|
| 165 |
-
cv_results = {}
|
| 166 |
-
for metric, values in fold_metrics.items():
|
| 167 |
-
cv_results[f'{metric}_mean'] = np.mean(values)
|
| 168 |
-
cv_results[f'{metric}_std'] = np.std(values)
|
| 169 |
-
|
| 170 |
-
logger.info(f"Cross-validation results: {cv_results}")
|
| 171 |
-
return cv_results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config.yaml
DELETED
|
@@ -1,189 +0,0 @@
|
|
| 1 |
-
# SafeRAG Configuration File
|
| 2 |
-
|
| 3 |
-
# Model Configuration
|
| 4 |
-
models:
|
| 5 |
-
embedding:
|
| 6 |
-
name: "BAAI/bge-large-en-v1.5"
|
| 7 |
-
device: "cuda"
|
| 8 |
-
batch_size: 32
|
| 9 |
-
|
| 10 |
-
reranker:
|
| 11 |
-
name: "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
| 12 |
-
device: "cuda"
|
| 13 |
-
batch_size: 32
|
| 14 |
-
|
| 15 |
-
generator:
|
| 16 |
-
name: "openai/gpt-oss-20b"
|
| 17 |
-
tensor_parallel_size: 1
|
| 18 |
-
gpu_memory_utilization: 0.9
|
| 19 |
-
max_tokens: 512
|
| 20 |
-
temperature: 0.7
|
| 21 |
-
top_p: 0.9
|
| 22 |
-
|
| 23 |
-
calibration:
|
| 24 |
-
type: "logistic" # logistic, random_forest, mlp
|
| 25 |
-
input_dim: 16
|
| 26 |
-
hidden_dim: 64
|
| 27 |
-
|
| 28 |
-
# Data Configuration
|
| 29 |
-
data:
|
| 30 |
-
datasets:
|
| 31 |
-
- "hotpotqa"
|
| 32 |
-
- "triviaqa"
|
| 33 |
-
- "nq_open"
|
| 34 |
-
|
| 35 |
-
knowledge_base:
|
| 36 |
-
name: "wikipedia"
|
| 37 |
-
language: "en"
|
| 38 |
-
date: "20231101"
|
| 39 |
-
|
| 40 |
-
preprocessing:
|
| 41 |
-
max_sentence_length: 512
|
| 42 |
-
min_sentence_length: 20
|
| 43 |
-
cache_dir: "./cache"
|
| 44 |
-
|
| 45 |
-
# Index Configuration
|
| 46 |
-
index:
|
| 47 |
-
type: "ivf" # flat, ivf
|
| 48 |
-
dimension: 1024
|
| 49 |
-
nlist: 4096
|
| 50 |
-
save_path: "./index/safrag"
|
| 51 |
-
|
| 52 |
-
# Retrieval Configuration
|
| 53 |
-
retrieval:
|
| 54 |
-
k: 20
|
| 55 |
-
rerank_k: 10
|
| 56 |
-
batch_size: 32
|
| 57 |
-
similarity_threshold: 0.3
|
| 58 |
-
|
| 59 |
-
# Risk Calibration Configuration
|
| 60 |
-
calibration:
|
| 61 |
-
tau1: 0.3 # Low risk threshold
|
| 62 |
-
tau2: 0.7 # High risk threshold
|
| 63 |
-
|
| 64 |
-
features:
|
| 65 |
-
- "num_passages"
|
| 66 |
-
- "avg_similarity"
|
| 67 |
-
- "std_similarity"
|
| 68 |
-
- "max_similarity"
|
| 69 |
-
- "min_similarity"
|
| 70 |
-
- "score_variance"
|
| 71 |
-
- "avg_token_overlap"
|
| 72 |
-
- "max_token_overlap"
|
| 73 |
-
- "avg_entity_overlap"
|
| 74 |
-
- "max_entity_overlap"
|
| 75 |
-
- "passage_consistency"
|
| 76 |
-
- "passage_consistency_std"
|
| 77 |
-
- "min_passage_similarity"
|
| 78 |
-
- "diversity"
|
| 79 |
-
- "topic_variance"
|
| 80 |
-
|
| 81 |
-
# Evaluation Configuration
|
| 82 |
-
evaluation:
|
| 83 |
-
metrics:
|
| 84 |
-
qa:
|
| 85 |
-
- "exact_match"
|
| 86 |
-
- "f1"
|
| 87 |
-
- "rouge1"
|
| 88 |
-
- "rouge2"
|
| 89 |
-
- "rougeL"
|
| 90 |
-
|
| 91 |
-
attribution:
|
| 92 |
-
- "precision"
|
| 93 |
-
- "recall"
|
| 94 |
-
- "f1"
|
| 95 |
-
- "citation_coverage"
|
| 96 |
-
- "citation_accuracy"
|
| 97 |
-
|
| 98 |
-
calibration:
|
| 99 |
-
- "ece"
|
| 100 |
-
- "mce"
|
| 101 |
-
- "auroc"
|
| 102 |
-
- "auprc"
|
| 103 |
-
|
| 104 |
-
system:
|
| 105 |
-
- "throughput"
|
| 106 |
-
- "latency"
|
| 107 |
-
- "gpu_utilization"
|
| 108 |
-
- "memory_usage"
|
| 109 |
-
|
| 110 |
-
test_size: 0.2
|
| 111 |
-
random_state: 42
|
| 112 |
-
cv_folds: 5
|
| 113 |
-
|
| 114 |
-
# System Configuration
|
| 115 |
-
system:
|
| 116 |
-
device: "cuda"
|
| 117 |
-
num_workers: 4
|
| 118 |
-
batch_size: 32
|
| 119 |
-
max_memory_gb: 16
|
| 120 |
-
|
| 121 |
-
monitoring:
|
| 122 |
-
enabled: true
|
| 123 |
-
interval: 1 # seconds
|
| 124 |
-
metrics:
|
| 125 |
-
- "cpu"
|
| 126 |
-
- "memory"
|
| 127 |
-
- "gpu"
|
| 128 |
-
- "disk"
|
| 129 |
-
|
| 130 |
-
# Output Configuration
|
| 131 |
-
output:
|
| 132 |
-
results_dir: "./results"
|
| 133 |
-
logs_dir: "./logs"
|
| 134 |
-
models_dir: "./models"
|
| 135 |
-
plots_dir: "./plots"
|
| 136 |
-
|
| 137 |
-
formats:
|
| 138 |
-
- "json"
|
| 139 |
-
- "csv"
|
| 140 |
-
- "html"
|
| 141 |
-
|
| 142 |
-
save_predictions: true
|
| 143 |
-
save_features: true
|
| 144 |
-
save_plots: true
|
| 145 |
-
|
| 146 |
-
# Logging Configuration
|
| 147 |
-
logging:
|
| 148 |
-
level: "INFO"
|
| 149 |
-
format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
| 150 |
-
file: "./logs/safrag.log"
|
| 151 |
-
max_size: "10MB"
|
| 152 |
-
backup_count: 5
|
| 153 |
-
|
| 154 |
-
# Hugging Face Configuration
|
| 155 |
-
huggingface:
|
| 156 |
-
cache_dir: "./cache"
|
| 157 |
-
token: null # Set your HF token here
|
| 158 |
-
hub_url: "https://huggingface.co"
|
| 159 |
-
|
| 160 |
-
spaces:
|
| 161 |
-
app_name: "safrag-demo"
|
| 162 |
-
hardware: "cpu" # cpu, gpu, cpu-basic, gpu-basic
|
| 163 |
-
visibility: "public"
|
| 164 |
-
|
| 165 |
-
# Experiment Configuration
|
| 166 |
-
experiments:
|
| 167 |
-
baseline:
|
| 168 |
-
enabled: true
|
| 169 |
-
output_dir: "./results/baseline"
|
| 170 |
-
|
| 171 |
-
safrag:
|
| 172 |
-
enabled: true
|
| 173 |
-
output_dir: "./results/safrag"
|
| 174 |
-
|
| 175 |
-
ablation:
|
| 176 |
-
enabled: true
|
| 177 |
-
output_dir: "./results/ablation"
|
| 178 |
-
|
| 179 |
-
studies:
|
| 180 |
-
- "no_reranking"
|
| 181 |
-
- "no_calibration"
|
| 182 |
-
- "different_embeddings"
|
| 183 |
-
- "different_thresholds"
|
| 184 |
-
- "different_calibration_models"
|
| 185 |
-
- "different_retrieval_k"
|
| 186 |
-
|
| 187 |
-
comprehensive:
|
| 188 |
-
enabled: true
|
| 189 |
-
output_dir: "./results/comprehensive"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data_processing/__init__.py
DELETED
|
@@ -1,4 +0,0 @@
|
|
| 1 |
-
from .data_loader import DataLoader
|
| 2 |
-
from .preprocessor import Preprocessor
|
| 3 |
-
|
| 4 |
-
__all__ = ['DataLoader', 'Preprocessor']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data_processing/data_loader.py
DELETED
|
@@ -1,29 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
import logging
|
| 3 |
-
from datasets import load_dataset
|
| 4 |
-
|
| 5 |
-
logger = logging.getLogger(__name__)
|
| 6 |
-
|
| 7 |
-
class DataLoader:
|
| 8 |
-
def __init__(self, cache_dir: str = "./cache"):
|
| 9 |
-
self.cache_dir = cache_dir
|
| 10 |
-
|
| 11 |
-
def load_msmarco_passage(self, split: str = "train"):
|
| 12 |
-
"""Load MS MARCO Passage Ranking dataset from Hugging Face (v2.1)"""
|
| 13 |
-
try:
|
| 14 |
-
logger.info(f"Downloading MS MARCO Passage Ranking {split} (v2.1) from Hugging Face")
|
| 15 |
-
ds = load_dataset("ms_marco", "v2.1", split=split)
|
| 16 |
-
return ds
|
| 17 |
-
except Exception as e:
|
| 18 |
-
logger.error(f"Failed to load MS MARCO Passage Ranking: {e}")
|
| 19 |
-
raise
|
| 20 |
-
|
| 21 |
-
def get_passage_dataset(self, split: str = "train"):
|
| 22 |
-
"""Load MS MARCO Passage Ranking dataset"""
|
| 23 |
-
try:
|
| 24 |
-
ds = self.load_msmarco_passage(split)
|
| 25 |
-
logger.info("MS MARCO Passage Ranking loaded successfully")
|
| 26 |
-
return ds
|
| 27 |
-
except Exception as e:
|
| 28 |
-
logger.error(f"Failed to load MS MARCO Passage Ranking: {e}")
|
| 29 |
-
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data_processing/preprocessor.py
DELETED
|
@@ -1,112 +0,0 @@
|
|
| 1 |
-
from typing import List, Dict, Any
|
| 2 |
-
import re
|
| 3 |
-
import logging
|
| 4 |
-
|
| 5 |
-
logger = logging.getLogger(__name__)
|
| 6 |
-
|
| 7 |
-
class Preprocessor:
|
| 8 |
-
def __init__(self):
|
| 9 |
-
"""Initialize preprocessor without external dependencies"""
|
| 10 |
-
pass
|
| 11 |
-
|
| 12 |
-
def clean_text(self, text: str) -> str:
|
| 13 |
-
"""Clean and normalize text"""
|
| 14 |
-
if not text:
|
| 15 |
-
return ""
|
| 16 |
-
|
| 17 |
-
# Remove extra whitespace
|
| 18 |
-
text = text.strip()
|
| 19 |
-
text = re.sub(r'\s+', ' ', text)
|
| 20 |
-
|
| 21 |
-
# Remove special characters but keep punctuation
|
| 22 |
-
text = re.sub(r'[^\w\s\.\,\!\?\;\:\-\(\)]', '', text)
|
| 23 |
-
|
| 24 |
-
return text.strip()
|
| 25 |
-
|
| 26 |
-
def extract_sentences(self, text: str) -> List[str]:
|
| 27 |
-
"""Extract sentences from text (simplified version without NLTK)"""
|
| 28 |
-
if not text:
|
| 29 |
-
return []
|
| 30 |
-
|
| 31 |
-
# Simple sentence splitting based on punctuation
|
| 32 |
-
sentences = re.split(r'[.!?]+', text)
|
| 33 |
-
sentences = [s.strip() for s in sentences if s.strip()]
|
| 34 |
-
|
| 35 |
-
return sentences
|
| 36 |
-
|
| 37 |
-
def tokenize(self, text: str) -> List[str]:
|
| 38 |
-
"""Tokenize text into words (simplified version)"""
|
| 39 |
-
if not text:
|
| 40 |
-
return []
|
| 41 |
-
|
| 42 |
-
# Simple word tokenization
|
| 43 |
-
words = re.findall(r'\b\w+\b', text.lower())
|
| 44 |
-
return words
|
| 45 |
-
|
| 46 |
-
def preprocess_passages(self, passages: List[str]) -> List[Dict[str, Any]]:
|
| 47 |
-
"""Preprocess a list of passages"""
|
| 48 |
-
processed = []
|
| 49 |
-
|
| 50 |
-
for i, passage in enumerate(passages):
|
| 51 |
-
if not passage:
|
| 52 |
-
continue
|
| 53 |
-
|
| 54 |
-
cleaned = self.clean_text(passage)
|
| 55 |
-
sentences = self.extract_sentences(cleaned)
|
| 56 |
-
tokens = self.tokenize(cleaned)
|
| 57 |
-
|
| 58 |
-
processed.append({
|
| 59 |
-
'id': i,
|
| 60 |
-
'text': cleaned,
|
| 61 |
-
'sentences': sentences,
|
| 62 |
-
'tokens': tokens,
|
| 63 |
-
'length': len(tokens)
|
| 64 |
-
})
|
| 65 |
-
|
| 66 |
-
return processed
|
| 67 |
-
|
| 68 |
-
def preprocess_qa_data(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 69 |
-
"""Preprocess QA data, auto convert dict/list fields to string"""
|
| 70 |
-
processed = []
|
| 71 |
-
def to_str(val):
|
| 72 |
-
if isinstance(val, dict):
|
| 73 |
-
# 拼接所有value
|
| 74 |
-
return " ".join([to_str(v) for v in val.values()])
|
| 75 |
-
elif isinstance(val, list):
|
| 76 |
-
return " ".join([to_str(v) for v in val])
|
| 77 |
-
elif val is None:
|
| 78 |
-
return ""
|
| 79 |
-
return str(val)
|
| 80 |
-
|
| 81 |
-
for item in data:
|
| 82 |
-
if not isinstance(item, dict):
|
| 83 |
-
continue
|
| 84 |
-
question = to_str(item.get('question', ''))
|
| 85 |
-
answer = to_str(item.get('answer', ''))
|
| 86 |
-
context = to_str(item.get('context', ''))
|
| 87 |
-
|
| 88 |
-
processed_item = {
|
| 89 |
-
'question': self.clean_text(question),
|
| 90 |
-
'answer': self.clean_text(answer),
|
| 91 |
-
'context': self.clean_text(context),
|
| 92 |
-
'question_tokens': self.tokenize(question),
|
| 93 |
-
'answer_tokens': self.tokenize(answer),
|
| 94 |
-
'context_tokens': self.tokenize(context)
|
| 95 |
-
}
|
| 96 |
-
processed.append(processed_item)
|
| 97 |
-
return processed
|
| 98 |
-
|
| 99 |
-
def create_chunks(self, text: str, chunk_size: int = 512, overlap: int = 50) -> List[str]:
|
| 100 |
-
"""Create overlapping text chunks"""
|
| 101 |
-
if not text:
|
| 102 |
-
return []
|
| 103 |
-
|
| 104 |
-
tokens = self.tokenize(text)
|
| 105 |
-
chunks = []
|
| 106 |
-
|
| 107 |
-
for i in range(0, len(tokens), chunk_size - overlap):
|
| 108 |
-
chunk_tokens = tokens[i:i + chunk_size]
|
| 109 |
-
chunk_text = ' '.join(chunk_tokens)
|
| 110 |
-
chunks.append(chunk_text)
|
| 111 |
-
|
| 112 |
-
return chunks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
eval/__init__.py
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
from .eval_qa import QAEvaluator
|
| 2 |
-
from .eval_attr import AttributionEvaluator
|
| 3 |
-
from .eval_calib import CalibrationEvaluator
|
| 4 |
-
from .eval_system import SystemEvaluator
|
| 5 |
-
|
| 6 |
-
__all__ = ['QAEvaluator', 'AttributionEvaluator', 'CalibrationEvaluator', 'SystemEvaluator']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
eval/eval_attr.py
DELETED
|
@@ -1,275 +0,0 @@
|
|
| 1 |
-
from typing import List, Dict, Any, Set
|
| 2 |
-
import numpy as np
|
| 3 |
-
from sentence_transformers import SentenceTransformer
|
| 4 |
-
from sklearn.metrics.pairwise import cosine_similarity
|
| 5 |
-
import logging
|
| 6 |
-
|
| 7 |
-
logger = logging.getLogger(__name__)
|
| 8 |
-
|
| 9 |
-
class AttributionEvaluator:
|
| 10 |
-
def __init__(self, embedding_model: str = "BAAI/bge-large-en-v1.5"):
|
| 11 |
-
self.embedding_model = SentenceTransformer(embedding_model)
|
| 12 |
-
|
| 13 |
-
def evaluate_attribution(self, answers: List[str],
|
| 14 |
-
retrieved_passages: List[List[Dict[str, Any]]],
|
| 15 |
-
supporting_facts: List[List[str]] = None) -> Dict[str, float]:
|
| 16 |
-
"""Evaluate attribution quality"""
|
| 17 |
-
|
| 18 |
-
if not answers or not retrieved_passages:
|
| 19 |
-
return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
|
| 20 |
-
|
| 21 |
-
precisions = []
|
| 22 |
-
recalls = []
|
| 23 |
-
f1_scores = []
|
| 24 |
-
|
| 25 |
-
for answer, passages, facts in zip(answers, retrieved_passages, supporting_facts or [[]] * len(answers)):
|
| 26 |
-
if not passages:
|
| 27 |
-
precisions.append(0.0)
|
| 28 |
-
recalls.append(0.0)
|
| 29 |
-
f1_scores.append(0.0)
|
| 30 |
-
continue
|
| 31 |
-
|
| 32 |
-
# Extract passage texts
|
| 33 |
-
passage_texts = [p.get('text', '') for p in passages]
|
| 34 |
-
|
| 35 |
-
# Calculate attribution metrics
|
| 36 |
-
if facts:
|
| 37 |
-
# Use provided supporting facts
|
| 38 |
-
precision, recall, f1 = self._calculate_attribution_metrics(
|
| 39 |
-
answer, passage_texts, facts
|
| 40 |
-
)
|
| 41 |
-
else:
|
| 42 |
-
# Use semantic similarity as proxy
|
| 43 |
-
precision, recall, f1 = self._calculate_semantic_attribution(
|
| 44 |
-
answer, passage_texts
|
| 45 |
-
)
|
| 46 |
-
|
| 47 |
-
precisions.append(precision)
|
| 48 |
-
recalls.append(recall)
|
| 49 |
-
f1_scores.append(f1)
|
| 50 |
-
|
| 51 |
-
return {
|
| 52 |
-
'precision': np.mean(precisions),
|
| 53 |
-
'recall': np.mean(recalls),
|
| 54 |
-
'f1': np.mean(f1_scores),
|
| 55 |
-
'precision_std': np.std(precisions),
|
| 56 |
-
'recall_std': np.std(recalls),
|
| 57 |
-
'f1_std': np.std(f1_scores)
|
| 58 |
-
}
|
| 59 |
-
|
| 60 |
-
def _calculate_attribution_metrics(self, answer: str, passages: List[str],
|
| 61 |
-
supporting_facts: List[str]) -> tuple:
|
| 62 |
-
"""Calculate attribution metrics using supporting facts"""
|
| 63 |
-
|
| 64 |
-
# Find which passages contain supporting facts
|
| 65 |
-
relevant_passages = set()
|
| 66 |
-
for fact in supporting_facts:
|
| 67 |
-
for i, passage in enumerate(passages):
|
| 68 |
-
if self._passage_contains_fact(passage, fact):
|
| 69 |
-
relevant_passages.add(i)
|
| 70 |
-
|
| 71 |
-
# Calculate metrics
|
| 72 |
-
total_passages = len(passages)
|
| 73 |
-
relevant_count = len(relevant_passages)
|
| 74 |
-
|
| 75 |
-
if total_passages == 0:
|
| 76 |
-
return 0.0, 0.0, 0.0
|
| 77 |
-
|
| 78 |
-
# Precision: relevant passages / total retrieved passages
|
| 79 |
-
precision = relevant_count / total_passages
|
| 80 |
-
|
| 81 |
-
# Recall: relevant passages / total supporting facts
|
| 82 |
-
recall = relevant_count / len(supporting_facts) if supporting_facts else 0.0
|
| 83 |
-
|
| 84 |
-
# F1 score
|
| 85 |
-
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
|
| 86 |
-
|
| 87 |
-
return precision, recall, f1
|
| 88 |
-
|
| 89 |
-
def _calculate_semantic_attribution(self, answer: str, passages: List[str]) -> tuple:
|
| 90 |
-
"""Calculate attribution using semantic similarity"""
|
| 91 |
-
|
| 92 |
-
if not passages:
|
| 93 |
-
return 0.0, 0.0, 0.0
|
| 94 |
-
|
| 95 |
-
# Encode answer and passages
|
| 96 |
-
answer_embedding = self.embedding_model.encode([answer])
|
| 97 |
-
passage_embeddings = self.embedding_model.encode(passages)
|
| 98 |
-
|
| 99 |
-
# Calculate similarities
|
| 100 |
-
similarities = cosine_similarity(answer_embedding, passage_embeddings)[0]
|
| 101 |
-
|
| 102 |
-
# Use threshold to determine relevant passages
|
| 103 |
-
threshold = 0.3
|
| 104 |
-
relevant_passages = similarities >= threshold
|
| 105 |
-
|
| 106 |
-
# Calculate metrics
|
| 107 |
-
total_passages = len(passages)
|
| 108 |
-
relevant_count = np.sum(relevant_passages)
|
| 109 |
-
|
| 110 |
-
precision = relevant_count / total_passages
|
| 111 |
-
recall = relevant_count / total_passages # Simplified for semantic method
|
| 112 |
-
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
|
| 113 |
-
|
| 114 |
-
return precision, recall, f1
|
| 115 |
-
|
| 116 |
-
def _passage_contains_fact(self, passage: str, fact: str) -> bool:
|
| 117 |
-
"""Check if passage contains a supporting fact"""
|
| 118 |
-
# Simple containment check
|
| 119 |
-
fact_words = set(fact.lower().split())
|
| 120 |
-
passage_words = set(passage.lower().split())
|
| 121 |
-
|
| 122 |
-
# Check if most fact words are in passage
|
| 123 |
-
overlap = len(fact_words & passage_words)
|
| 124 |
-
return overlap >= len(fact_words) * 0.7
|
| 125 |
-
|
| 126 |
-
def evaluate_citation_quality(self, answers: List[str],
|
| 127 |
-
citations: List[List[Dict[str, Any]]]) -> Dict[str, float]:
|
| 128 |
-
"""Evaluate citation quality in answers"""
|
| 129 |
-
|
| 130 |
-
if not answers or not citations:
|
| 131 |
-
return {'citation_coverage': 0.0, 'citation_accuracy': 0.0}
|
| 132 |
-
|
| 133 |
-
coverage_scores = []
|
| 134 |
-
accuracy_scores = []
|
| 135 |
-
|
| 136 |
-
for answer, answer_citations in zip(answers, citations):
|
| 137 |
-
# Citation coverage: percentage of answer that is cited
|
| 138 |
-
coverage = self._calculate_citation_coverage(answer, answer_citations)
|
| 139 |
-
coverage_scores.append(coverage)
|
| 140 |
-
|
| 141 |
-
# Citation accuracy: percentage of citations that are relevant
|
| 142 |
-
accuracy = self._calculate_citation_accuracy(answer, answer_citations)
|
| 143 |
-
accuracy_scores.append(accuracy)
|
| 144 |
-
|
| 145 |
-
return {
|
| 146 |
-
'citation_coverage': np.mean(coverage_scores),
|
| 147 |
-
'citation_accuracy': np.mean(accuracy_scores),
|
| 148 |
-
'citation_coverage_std': np.std(coverage_scores),
|
| 149 |
-
'citation_accuracy_std': np.std(accuracy_scores)
|
| 150 |
-
}
|
| 151 |
-
|
| 152 |
-
def _calculate_citation_coverage(self, answer: str, citations: List[Dict[str, Any]]) -> float:
|
| 153 |
-
"""Calculate what percentage of answer is covered by citations"""
|
| 154 |
-
if not citations:
|
| 155 |
-
return 0.0
|
| 156 |
-
|
| 157 |
-
# Simple heuristic: check if answer contains citation markers
|
| 158 |
-
import re
|
| 159 |
-
citation_markers = re.findall(r'\[\d+\]', answer)
|
| 160 |
-
|
| 161 |
-
if not citation_markers:
|
| 162 |
-
return 0.0
|
| 163 |
-
|
| 164 |
-
# Estimate coverage based on citation density
|
| 165 |
-
answer_length = len(answer.split())
|
| 166 |
-
citation_density = len(citation_markers) / answer_length if answer_length > 0 else 0
|
| 167 |
-
|
| 168 |
-
return min(1.0, citation_density * 10) # Scale factor
|
| 169 |
-
|
| 170 |
-
def _calculate_citation_accuracy(self, answer: str, citations: List[Dict[str, Any]]) -> float:
|
| 171 |
-
"""Calculate accuracy of citations"""
|
| 172 |
-
if not citations:
|
| 173 |
-
return 0.0
|
| 174 |
-
|
| 175 |
-
# Simple heuristic: check if cited passages are relevant to answer
|
| 176 |
-
answer_words = set(answer.lower().split())
|
| 177 |
-
relevant_citations = 0
|
| 178 |
-
|
| 179 |
-
for citation in citations:
|
| 180 |
-
citation_text = citation.get('text', '')
|
| 181 |
-
citation_words = set(citation_text.lower().split())
|
| 182 |
-
|
| 183 |
-
# Check word overlap
|
| 184 |
-
overlap = len(answer_words & citation_words)
|
| 185 |
-
if overlap >= 3: # Threshold for relevance
|
| 186 |
-
relevant_citations += 1
|
| 187 |
-
|
| 188 |
-
return relevant_citations / len(citations)
|
| 189 |
-
|
| 190 |
-
def evaluate_retrieval_quality(self, queries: List[str],
|
| 191 |
-
retrieved_passages: List[List[Dict[str, Any]]],
|
| 192 |
-
relevant_passages: List[List[str]] = None) -> Dict[str, float]:
|
| 193 |
-
"""Evaluate retrieval quality"""
|
| 194 |
-
|
| 195 |
-
if not queries or not retrieved_passages:
|
| 196 |
-
return {'retrieval_precision': 0.0, 'retrieval_recall': 0.0, 'retrieval_f1': 0.0}
|
| 197 |
-
|
| 198 |
-
precisions = []
|
| 199 |
-
recalls = []
|
| 200 |
-
f1_scores = []
|
| 201 |
-
|
| 202 |
-
for query, passages, relevant in zip(queries, retrieved_passages, relevant_passages or [[]] * len(queries)):
|
| 203 |
-
if not passages:
|
| 204 |
-
precisions.append(0.0)
|
| 205 |
-
recalls.append(0.0)
|
| 206 |
-
f1_scores.append(0.0)
|
| 207 |
-
continue
|
| 208 |
-
|
| 209 |
-
# Calculate retrieval metrics
|
| 210 |
-
if relevant:
|
| 211 |
-
precision, recall, f1 = self._calculate_retrieval_metrics(passages, relevant)
|
| 212 |
-
else:
|
| 213 |
-
# Use semantic similarity as proxy
|
| 214 |
-
precision, recall, f1 = self._calculate_semantic_retrieval(query, passages)
|
| 215 |
-
|
| 216 |
-
precisions.append(precision)
|
| 217 |
-
recalls.append(recall)
|
| 218 |
-
f1_scores.append(f1)
|
| 219 |
-
|
| 220 |
-
return {
|
| 221 |
-
'retrieval_precision': np.mean(precisions),
|
| 222 |
-
'retrieval_recall': np.mean(recalls),
|
| 223 |
-
'retrieval_f1': np.mean(f1_scores),
|
| 224 |
-
'retrieval_precision_std': np.std(precisions),
|
| 225 |
-
'retrieval_recall_std': np.std(recalls),
|
| 226 |
-
'retrieval_f1_std': np.std(f1_scores)
|
| 227 |
-
}
|
| 228 |
-
|
| 229 |
-
def _calculate_retrieval_metrics(self, passages: List[Dict[str, Any]],
|
| 230 |
-
relevant_passages: List[str]) -> tuple:
|
| 231 |
-
"""Calculate retrieval metrics using ground truth"""
|
| 232 |
-
|
| 233 |
-
retrieved_texts = [p.get('text', '') for p in passages]
|
| 234 |
-
|
| 235 |
-
# Find relevant retrieved passages
|
| 236 |
-
relevant_retrieved = 0
|
| 237 |
-
for retrieved in retrieved_texts:
|
| 238 |
-
for relevant in relevant_passages:
|
| 239 |
-
if self._passage_contains_fact(retrieved, relevant):
|
| 240 |
-
relevant_retrieved += 1
|
| 241 |
-
break
|
| 242 |
-
|
| 243 |
-
total_retrieved = len(passages)
|
| 244 |
-
total_relevant = len(relevant_passages)
|
| 245 |
-
|
| 246 |
-
precision = relevant_retrieved / total_retrieved if total_retrieved > 0 else 0.0
|
| 247 |
-
recall = relevant_retrieved / total_relevant if total_relevant > 0 else 0.0
|
| 248 |
-
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
|
| 249 |
-
|
| 250 |
-
return precision, recall, f1
|
| 251 |
-
|
| 252 |
-
def _calculate_semantic_retrieval(self, query: str, passages: List[Dict[str, Any]]) -> tuple:
    """Proxy precision/recall/F1 from query-passage cosine similarity.

    With no gold set available, a passage is counted as relevant when its
    cosine similarity to the query reaches a fixed threshold; recall then
    mirrors precision by construction.
    """
    if not passages:
        return 0.0, 0.0, 0.0

    texts = [p.get('text', '') for p in passages]
    query_vec = self.embedding_model.encode([query])
    passage_vecs = self.embedding_model.encode(texts)

    sims = cosine_similarity(query_vec, passage_vecs)[0]

    # Passages at or above this similarity count as relevant.
    relevance_threshold = 0.3
    n_relevant = np.sum(sims >= relevance_threshold)

    precision = n_relevant / len(passages)
    recall = n_relevant / len(passages)  # simplified: identical to precision
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else 0.0

    return precision, recall, f1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
eval/eval_calib.py
DELETED
|
@@ -1,269 +0,0 @@
|
|
| 1 |
-
from typing import List, Dict, Any
|
| 2 |
-
import numpy as np
|
| 3 |
-
from sklearn.metrics import roc_auc_score, average_precision_score
|
| 4 |
-
import matplotlib.pyplot as plt
|
| 5 |
-
import logging
|
| 6 |
-
|
| 7 |
-
logger = logging.getLogger(__name__)
|
| 8 |
-
|
| 9 |
-
class CalibrationEvaluator:
    """Calibration metrics for binary predictions.

    All methods take parallel Python lists: ``predictions`` are confidence
    scores in [0, 1] and ``labels`` are 0/1 correctness indicators.
    """

    def __init__(self):
        # Stateless evaluator; nothing to configure.
        pass

    @staticmethod
    def _bin_masks(predictions: np.ndarray, n_bins: int):
        """Yield ``(lower, upper, mask)`` for ``n_bins`` equal-width bins.

        Bins are half-open ``(lower, upper]``, except that the first bin also
        includes its lower edge. Bug fix: the original strict ``>`` excluded
        predictions of exactly 0.0 from every bin.
        """
        bin_boundaries = np.linspace(0, 1, n_bins + 1)
        for bin_lower, bin_upper in zip(bin_boundaries[:-1], bin_boundaries[1:]):
            mask = (predictions > bin_lower) & (predictions <= bin_upper)
            if bin_lower <= 0.0:
                mask |= predictions == bin_lower
            yield bin_lower, bin_upper, mask

    def expected_calibration_error(self, predictions: List[float],
                                   labels: List[int], n_bins: int = 10) -> float:
        """Expected Calibration Error (ECE).

        Bin-size-weighted mean of |avg confidence - accuracy| per bin.
        Returns 0.0 for empty input.
        """
        if not predictions or not labels:
            return 0.0

        predictions = np.array(predictions)
        labels = np.array(labels)

        ece = 0
        for _, _, in_bin in self._bin_masks(predictions, n_bins):
            prop_in_bin = in_bin.mean()
            if prop_in_bin > 0:
                accuracy_in_bin = labels[in_bin].mean()
                avg_confidence_in_bin = predictions[in_bin].mean()
                ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return ece

    def maximum_calibration_error(self, predictions: List[float],
                                  labels: List[int], n_bins: int = 10) -> float:
        """Maximum Calibration Error (MCE).

        Largest |avg confidence - accuracy| over non-empty bins.
        Returns 0.0 for empty input.
        """
        if not predictions or not labels:
            return 0.0

        predictions = np.array(predictions)
        labels = np.array(labels)

        mce = 0
        for _, _, in_bin in self._bin_masks(predictions, n_bins):
            if in_bin.sum() > 0:
                accuracy_in_bin = labels[in_bin].mean()
                avg_confidence_in_bin = predictions[in_bin].mean()
                mce = max(mce, np.abs(avg_confidence_in_bin - accuracy_in_bin))

        return mce

    def reliability_diagram(self, predictions: List[float], labels: List[int],
                            n_bins: int = 10, save_path: str = None) -> Dict[str, Any]:
        """Build (and optionally save) a reliability diagram.

        Returns the per-bin data (centers, accuracies, confidences, counts);
        the matplotlib figure is always closed after drawing.
        """
        if not predictions or not labels:
            return {}

        predictions = np.array(predictions)
        labels = np.array(labels)

        bin_centers, accuracies, confidences, counts = [], [], [], []
        for bin_lower, bin_upper, in_bin in self._bin_masks(predictions, n_bins):
            count = in_bin.sum()
            if count > 0:
                bin_centers.append((bin_lower + bin_upper) / 2)
                accuracies.append(labels[in_bin].mean())
                confidences.append(predictions[in_bin].mean())
                counts.append(count)

        plt.figure(figsize=(8, 6))
        plt.bar(bin_centers, accuracies, width=0.1, alpha=0.7, label='Accuracy')
        plt.plot([0, 1], [0, 1], 'r--', label='Perfect Calibration')
        plt.xlabel('Confidence')
        plt.ylabel('Accuracy')
        plt.title('Reliability Diagram')
        plt.legend()
        plt.grid(True, alpha=0.3)

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')

        plt.close()

        return {
            'bin_centers': bin_centers,
            'accuracies': accuracies,
            'confidences': confidences,
            'counts': counts
        }

    def auroc(self, predictions: List[float], labels: List[int]) -> float:
        """Area under the ROC curve of confidence vs correctness.

        Returns 0.0 for empty or degenerate input (e.g. one class only).
        """
        if not predictions or not labels:
            return 0.0

        try:
            return roc_auc_score(labels, predictions)
        except ValueError:
            # Bug fix: was a bare ``except:``; sklearn raises ValueError
            # when only a single class is present.
            return 0.0

    def auprc(self, predictions: List[float], labels: List[int]) -> float:
        """Area under the precision-recall curve.

        Returns 0.0 for empty or degenerate input.
        """
        if not predictions or not labels:
            return 0.0

        try:
            return average_precision_score(labels, predictions)
        except ValueError:
            # Bug fix: was a bare ``except:`` (see auroc).
            return 0.0

    def risk_coverage_curve(self, predictions: List[float], labels: List[int],
                            risk_thresholds: List[float] = None) -> Dict[str, Any]:
        """Accuracy/coverage of the subset with prediction <= each threshold.

        Bug fix: ``risk_thresholds`` may now be a plain Python list — the
        original called ``.tolist()`` on it unconditionally, which raised
        AttributeError for list input.
        """
        if not predictions or not labels:
            return {'thresholds': [], 'coverage': [], 'accuracy': []}

        predictions = np.array(predictions)
        labels = np.array(labels)

        if risk_thresholds is None:
            risk_thresholds = np.linspace(0, 1, 21)
        else:
            risk_thresholds = np.asarray(risk_thresholds, dtype=float)

        coverages = []
        accuracies = []

        for threshold in risk_thresholds:
            # Select predictions with risk <= threshold.
            selected = predictions <= threshold

            if selected.sum() > 0:
                coverage = selected.mean()
                accuracy = labels[selected].mean()
            else:
                coverage = 0.0
                accuracy = 0.0

            coverages.append(coverage)
            accuracies.append(accuracy)

        return {
            'thresholds': risk_thresholds.tolist(),
            'coverage': coverages,
            'accuracy': accuracies
        }

    def evaluate_calibration(self, predictions: List[float], labels: List[int]) -> Dict[str, float]:
        """Comprehensive calibration evaluation (ECE, MCE, AUROC, AUPRC).

        NOTE: despite the ``Dict[str, float]`` annotation, the result also
        carries a nested 'risk_coverage' dict.
        """
        if not predictions or not labels:
            return {
                'ece': 0.0,
                'mce': 0.0,
                'auroc': 0.0,
                'auprc': 0.0
            }

        metrics = {
            'ece': self.expected_calibration_error(predictions, labels),
            'mce': self.maximum_calibration_error(predictions, labels),
            'auroc': self.auroc(predictions, labels),
            'auprc': self.auprc(predictions, labels)
        }

        # Risk-coverage analysis (nested dict).
        metrics['risk_coverage'] = self.risk_coverage_curve(predictions, labels)

        return metrics

    def plot_calibration_curves(self, predictions: List[float], labels: List[int],
                                save_path: str = None) -> None:
        """Write a 2x2 summary figure: reliability diagram, risk-coverage
        curve, confidence histogram, and accuracy vs confidence."""
        if not predictions or not labels:
            return

        fig, axes = plt.subplots(2, 2, figsize=(12, 10))

        # Reliability diagram (reuses the per-bin data).
        reliability_data = self.reliability_diagram(predictions, labels)
        if reliability_data:
            axes[0, 0].bar(reliability_data['bin_centers'], reliability_data['accuracies'],
                           width=0.1, alpha=0.7)
            axes[0, 0].plot([0, 1], [0, 1], 'r--')
            axes[0, 0].set_xlabel('Confidence')
            axes[0, 0].set_ylabel('Accuracy')
            axes[0, 0].set_title('Reliability Diagram')
            axes[0, 0].grid(True, alpha=0.3)

        # Risk-coverage curve.
        risk_coverage = self.risk_coverage_curve(predictions, labels)
        if risk_coverage['thresholds']:
            axes[0, 1].plot(risk_coverage['coverage'], risk_coverage['accuracy'], 'b-')
            axes[0, 1].set_xlabel('Coverage')
            axes[0, 1].set_ylabel('Accuracy')
            axes[0, 1].set_title('Risk-Coverage Curve')
            axes[0, 1].grid(True, alpha=0.3)

        # Confidence distribution.
        axes[1, 0].hist(predictions, bins=20, alpha=0.7, edgecolor='black')
        axes[1, 0].set_xlabel('Confidence')
        axes[1, 0].set_ylabel('Count')
        axes[1, 0].set_title('Confidence Distribution')
        axes[1, 0].grid(True, alpha=0.3)

        # Accuracy vs confidence. Bug fixes: the arrays are built once
        # (previously rebuilt inside the loop) and points are plotted at bin
        # midpoints (previously at the bins' left edges, mislabeling them).
        preds = np.array(predictions)
        labs = np.array(labels)
        edges = np.linspace(0, 1, 11)
        midpoints = (edges[:-1] + edges[1:]) / 2
        accuracies = []
        for lo, hi in zip(edges[:-1], edges[1:]):
            mask = (preds >= lo) & (preds < hi)
            accuracies.append(labs[mask].mean() if mask.sum() > 0 else 0)

        axes[1, 1].plot(midpoints, accuracies, 'bo-')
        axes[1, 1].plot([0, 1], [0, 1], 'r--')
        axes[1, 1].set_xlabel('Confidence')
        axes[1, 1].set_ylabel('Accuracy')
        axes[1, 1].set_title('Accuracy vs Confidence')
        axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')

        plt.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
eval/eval_qa.py
DELETED
|
@@ -1,137 +0,0 @@
|
|
| 1 |
-
import re
|
| 2 |
-
from typing import List, Dict, Any
|
| 3 |
-
import numpy as np
|
| 4 |
-
from evaluate import load
|
| 5 |
-
import logging
|
| 6 |
-
|
| 7 |
-
logger = logging.getLogger(__name__)
|
| 8 |
-
|
| 9 |
-
class QAEvaluator:
    """Compute QA answer-quality metrics: exact match, token-level F1,
    ROUGE, and SQuAD-style EM/F1."""

    def __init__(self):
        # HuggingFace `evaluate` metrics, loaded once and reused.
        self.squad_metric = load("squad")
        self.rouge_metric = load("rouge")

    def exact_match(self, predictions: List[str], references: List[str]) -> float:
        """Fraction of predictions matching their reference exactly after
        normalization. Returns 0.0 for empty input."""
        matches = 0
        for pred, ref in zip(predictions, references):
            if self._normalize_answer(pred) == self._normalize_answer(ref):
                matches += 1
        return matches / len(predictions) if predictions else 0.0

    def f1_score(self, predictions: List[str], references: List[str]) -> float:
        """Mean token-overlap F1 over all (prediction, reference) pairs."""
        f1_scores = []
        for pred, ref in zip(predictions, references):
            f1 = self._calculate_f1(pred, ref)
            f1_scores.append(f1)
        return np.mean(f1_scores) if f1_scores else 0.0

    def rouge_score(self, predictions: List[str], references: List[str]) -> Dict[str, float]:
        """ROUGE-1/2/L via the HuggingFace `rouge` metric; zeros for empty input."""
        if not predictions or not references:
            return {'rouge1': 0.0, 'rouge2': 0.0, 'rougeL': 0.0}

        results = self.rouge_metric.compute(
            predictions=predictions,
            references=references
        )

        return {
            'rouge1': results['rouge1'],
            'rouge2': results['rouge2'],
            'rougeL': results['rougeL']
        }

    def squad_metrics(self, predictions: List[str], references: List[str]) -> Dict[str, float]:
        """SQuAD-style exact match and F1; zeros for empty input."""
        if not predictions or not references:
            return {'exact_match': 0.0, 'f1': 0.0}

        # Wrap plain strings in the id/answers structure the metric expects.
        formatted_predictions = [{"prediction_text": pred, "id": str(i)}
                                 for i, pred in enumerate(predictions)]
        formatted_references = [{"answers": {"text": [ref], "answer_start": [0]}, "id": str(i)}
                                for i, ref in enumerate(references)]

        results = self.squad_metric.compute(
            predictions=formatted_predictions,
            references=formatted_references
        )

        return {
            'exact_match': results['exact_match'],
            'f1': results['f1']
        }

    def evaluate_batch(self, predictions: List[str], references: List[str]) -> Dict[str, float]:
        """Evaluate a batch: EM, F1, ROUGE and SQuAD metrics in one dict."""
        metrics = {}

        # Basic metrics.
        metrics['exact_match'] = self.exact_match(predictions, references)
        metrics['f1'] = self.f1_score(predictions, references)

        # ROUGE metrics.
        metrics.update(self.rouge_score(predictions, references))

        # SQuAD metrics.
        metrics.update(self.squad_metrics(predictions, references))

        return metrics

    def _normalize_answer(self, answer: str) -> str:
        """SQuAD-style normalization: lowercase, strip punctuation, drop
        English articles (a/an/the), collapse whitespace."""
        import string  # bug fix: `string.punctuation` was used but the module was never imported

        def remove_articles(text):
            return re.sub(r'\b(a|an|the)\b', ' ', text)

        def white_space_fix(text):
            return ' '.join(text.split())

        def remove_punc(text):
            exclude = set(string.punctuation)
            return ''.join(ch for ch in text if ch not in exclude)

        def lower(text):
            return text.lower()

        return white_space_fix(remove_articles(remove_punc(lower(answer))))

    def _calculate_f1(self, prediction: str, reference: str) -> float:
        """Token-overlap F1 between one prediction and one reference.

        NOTE: overlap is set-based, so repeated tokens count once.
        """
        pred_tokens = self._normalize_answer(prediction).split()
        ref_tokens = self._normalize_answer(reference).split()

        if len(ref_tokens) == 0:
            # Empty reference: perfect only if prediction is also empty.
            return 1.0 if len(pred_tokens) == 0 else 0.0

        common = set(pred_tokens) & set(ref_tokens)

        if len(common) == 0:
            return 0.0

        precision = len(common) / len(pred_tokens)
        recall = len(common) / len(ref_tokens)

        return 2 * precision * recall / (precision + recall)

    def evaluate_with_context(self, predictions: List[str], references: List[str],
                              contexts: List[str]) -> Dict[str, float]:
        """Batch metrics plus 'context_support': mean fraction of each
        prediction's words that also appear in its context."""
        metrics = self.evaluate_batch(predictions, references)

        context_scores = []
        for pred, context in zip(predictions, contexts):
            # Word overlap between prediction and context as a support proxy.
            pred_words = set(pred.lower().split())
            context_words = set(context.lower().split())
            overlap = len(pred_words & context_words) / len(pred_words) if pred_words else 0
            context_scores.append(overlap)

        # Bug fix: np.mean([]) returns NaN with a warning; guard empty input.
        metrics['context_support'] = np.mean(context_scores) if context_scores else 0.0

        return metrics
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
eval/eval_system.py
DELETED
|
@@ -1,297 +0,0 @@
|
|
| 1 |
-
import time
|
| 2 |
-
import psutil
|
| 3 |
-
import GPUtil
|
| 4 |
-
from typing import List, Dict, Any, Optional
|
| 5 |
-
import numpy as np
|
| 6 |
-
import logging
|
| 7 |
-
import threading
|
| 8 |
-
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 9 |
-
|
| 10 |
-
logger = logging.getLogger(__name__)
|
| 11 |
-
|
| 12 |
-
class SystemEvaluator:
    """Measure runtime performance (latency, throughput) and system resource
    usage (CPU / memory / GPU / disk) of RAG components.

    Resource monitoring runs on a background thread between
    start_monitoring() and stop_monitoring(), sampling roughly every 1-2 s
    (psutil.cpu_percent itself blocks ~1 s per sample).
    """

    def __init__(self):
        self.monitoring = False    # flag polled by the monitor thread
        self.metrics = []          # one dict of resource readings per sample
        self.monitor_thread = None

    def start_monitoring(self):
        """Start background resource sampling (clears previous samples)."""
        self.monitoring = True
        self.metrics = []
        # Daemon thread so a forgotten stop_monitoring() cannot block
        # interpreter exit (robustness fix: was non-daemon).
        self.monitor_thread = threading.Thread(target=self._monitor_system, daemon=True)
        self.monitor_thread.start()
        logger.info("Started system monitoring")

    def stop_monitoring(self):
        """Signal the monitor thread to stop and wait for it to finish."""
        self.monitoring = False
        if self.monitor_thread:
            self.monitor_thread.join()
        logger.info("Stopped system monitoring")

    def _monitor_system(self):
        """Thread target: append resource samples while monitoring is on."""
        while self.monitoring:
            try:
                # CPU usage (blocks ~1 s while sampling).
                cpu_percent = psutil.cpu_percent(interval=1)

                memory = psutil.virtual_memory()
                gpu_metrics = self._get_gpu_metrics()
                disk = psutil.disk_usage('/')

                self.metrics.append({
                    'timestamp': time.time(),
                    'cpu_percent': cpu_percent,
                    'memory_percent': memory.percent,
                    'memory_used_gb': memory.used / (1024 ** 3),
                    'disk_percent': disk.percent,
                    **gpu_metrics
                })

            except Exception as e:
                logger.error(f"Error monitoring system: {e}")

            time.sleep(1)  # pause between samples

    def _get_gpu_metrics(self) -> Dict[str, Any]:
        """Utilization/memory/temperature of the first GPU, or zeros when no
        GPU (or GPUtil backend) is available."""
        try:
            gpus = GPUtil.getGPUs()
            if gpus:
                gpu = gpus[0]  # first GPU only
                return {
                    'gpu_utilization': gpu.load * 100,
                    'gpu_memory_used': gpu.memoryUsed,
                    'gpu_memory_total': gpu.memoryTotal,
                    'gpu_memory_percent': (gpu.memoryUsed / gpu.memoryTotal) * 100,
                    'gpu_temperature': gpu.temperature
                }
        except Exception:
            # Bug fix: was a bare ``except:`` that also swallowed
            # KeyboardInterrupt/SystemExit; GPU probing is best-effort.
            pass

        return {
            'gpu_utilization': 0,
            'gpu_memory_used': 0,
            'gpu_memory_total': 0,
            'gpu_memory_percent': 0,
            'gpu_temperature': 0
        }

    def measure_throughput(self, func, args_list: List[tuple],
                           max_workers: int = 4) -> Dict[str, Any]:
        """Run ``func(*args)`` for every tuple in ``args_list`` on a thread
        pool and report queries-per-second and per-query averages."""
        start_time = time.time()

        results = []
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(func, *args) for args in args_list]

            for future in as_completed(futures):
                try:
                    results.append(future.result())
                except Exception as e:
                    logger.error(f"Error in throughput measurement: {e}")

        total_time = time.time() - start_time
        # Bug fix: guard against zero elapsed time on coarse clocks
        # (previously an unguarded division).
        throughput = len(results) / total_time if total_time > 0 else 0.0

        return {
            'total_queries': len(args_list),
            'successful_queries': len(results),
            'total_time': total_time,
            'throughput_qps': throughput,
            'avg_time_per_query': total_time / len(args_list) if args_list else 0
        }

    def measure_latency(self, func, args: tuple, num_runs: int = 10) -> Dict[str, Any]:
        """Call ``func(*args)`` ``num_runs`` times and report latency
        statistics (mean, percentiles, min/max/std). Failed runs are dropped;
        all-zero stats are returned if every run failed."""
        latencies = []

        for _ in range(num_runs):
            start_time = time.time()
            try:
                func(*args)  # return value is not needed for timing
                latencies.append(time.time() - start_time)
            except Exception as e:
                logger.error(f"Error in latency measurement: {e}")
                latencies.append(float('inf'))

        # Drop failed runs (recorded as inf above).
        latencies = [l for l in latencies if l != float('inf')]

        if not latencies:
            return {
                'avg_latency': 0,
                'p50_latency': 0,
                'p95_latency': 0,
                'p99_latency': 0,
                'min_latency': 0,
                'max_latency': 0,
                'std_latency': 0
            }

        arr = np.array(latencies)

        return {
            'avg_latency': np.mean(arr),
            'p50_latency': np.percentile(arr, 50),
            'p95_latency': np.percentile(arr, 95),
            'p99_latency': np.percentile(arr, 99),
            'min_latency': np.min(arr),
            'max_latency': np.max(arr),
            'std_latency': np.std(arr)
        }

    def measure_batch_latency(self, func, args_list: List[tuple],
                              batch_sizes: List[int] = None) -> Dict[str, Any]:
        """Measure sequential per-batch latency for several batch sizes.

        Bug fix: ``batch_sizes`` previously used a mutable list default;
        passing None now yields the same [1, 4, 8, 16] behavior.
        """
        if batch_sizes is None:
            batch_sizes = [1, 4, 8, 16]

        results = {}

        for batch_size in batch_sizes:
            batch_latencies = []

            # Walk args_list in consecutive chunks of ``batch_size``.
            for i in range(0, len(args_list), batch_size):
                batch_args = args_list[i:i + batch_size]

                start_time = time.time()
                try:
                    for args in batch_args:
                        func(*args)  # results discarded; only timing matters
                    batch_latencies.append(time.time() - start_time)
                except Exception as e:
                    logger.error(f"Error in batch latency measurement: {e}")

            if batch_latencies:
                results[f'batch_size_{batch_size}'] = {
                    'avg_latency': np.mean(batch_latencies),
                    'p95_latency': np.percentile(batch_latencies, 95),
                    'throughput': batch_size / np.mean(batch_latencies)
                }

        return results

    def get_system_stats(self) -> Dict[str, Any]:
        """Summarize the samples gathered while monitoring.

        'monitoring_duration' is the sample count (roughly seconds, since
        sampling runs about once per second). Empty dict if no samples.
        """
        if not self.metrics:
            return {}

        cpu_values = [m['cpu_percent'] for m in self.metrics]
        memory_values = [m['memory_percent'] for m in self.metrics]
        gpu_values = [m.get('gpu_utilization', 0) for m in self.metrics]

        def _summary(values):
            # avg/max/min/std over one resource series.
            return {
                'avg': np.mean(values),
                'max': np.max(values),
                'min': np.min(values),
                'std': np.std(values)
            }

        return {
            'monitoring_duration': len(self.metrics),
            'cpu': _summary(cpu_values),
            'memory': _summary(memory_values),
            'gpu': _summary(gpu_values)
        }

    def evaluate_retrieval_performance(self, retriever, queries: List[str],
                                       k: int = 10) -> Dict[str, Any]:
        """Latency and throughput of ``retriever.retrieve_single``."""
        latency_stats = self.measure_latency(
            retriever.retrieve_single,
            (queries[0], k),
            num_runs=5
        )

        throughput_stats = self.measure_throughput(
            retriever.retrieve_single,
            [(query, k) for query in queries[:10]],  # cap load for the throughput test
            max_workers=4
        )

        return {
            'latency': latency_stats,
            'throughput': throughput_stats
        }

    def evaluate_generation_performance(self, generator, questions: List[str],
                                        passages_list: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
        """Latency and throughput of ``generator.generate_with_strategy``."""
        latency_stats = self.measure_latency(
            generator.generate_with_strategy,
            (questions[0], passages_list[0]),
            num_runs=5
        )

        throughput_stats = self.measure_throughput(
            generator.generate_with_strategy,
            list(zip(questions[:5], passages_list[:5])),  # cap load for the throughput test
            max_workers=2
        )

        return {
            'latency': latency_stats,
            'throughput': throughput_stats
        }

    def evaluate_end_to_end_performance(self, rag_system, queries: List[str]) -> Dict[str, Any]:
        """Latency and throughput of the full ``rag_system.query`` pipeline."""
        latency_stats = self.measure_latency(
            rag_system.query,
            (queries[0],),
            num_runs=5
        )

        throughput_stats = self.measure_throughput(
            rag_system.query,
            [(query,) for query in queries[:10]],  # cap load for the throughput test
            max_workers=2
        )

        return {
            'latency': latency_stats,
            'throughput': throughput_stats
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
exp_pipeline/pipeline.py
DELETED
|
@@ -1,56 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
End-to-end pipeline for dataset download, preprocessing, embedding, and indexing.
|
| 3 |
-
"""
|
| 4 |
-
import logging
|
| 5 |
-
from data_processing.data_loader import DataLoader
|
| 6 |
-
from data_processing.preprocessor import Preprocessor
|
| 7 |
-
from retriever.embedder import Embedder
|
| 8 |
-
from retriever.faiss_index import build_faiss_index
|
| 9 |
-
|
| 10 |
-
logger = logging.getLogger(__name__)
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
def run_pipeline(split: str = "train"):
|
| 14 |
-
# 1. 下载MS MARCO Passage Ranking数据集
|
| 15 |
-
data_loader = DataLoader()
|
| 16 |
-
raw_data = data_loader.get_passage_dataset(split)
|
| 17 |
-
logger.info(f"Loaded {len(raw_data)} samples from MS MARCO Passage Ranking [{split}]")
|
| 18 |
-
print("data_loader\n")
|
| 19 |
-
|
| 20 |
-
# 2. 预处理数据
|
| 21 |
-
preprocessor = Preprocessor()
|
| 22 |
-
# HuggingFace datasets对象转list
|
| 23 |
-
if hasattr(raw_data, "to_dict"):
|
| 24 |
-
raw_data = raw_data.to_dict()
|
| 25 |
-
raw_data = [dict(zip(raw_data.keys(), v)) for v in zip(*raw_data.values())]
|
| 26 |
-
print("raw_data\n")
|
| 27 |
-
|
| 28 |
-
# MS MARCO Passage v2.1: 用passages["passage_text"]字段
|
| 29 |
-
passages = []
|
| 30 |
-
for item in raw_data:
|
| 31 |
-
if "passages" in item and "passage_text" in item["passages"]:
|
| 32 |
-
passages.extend(item["passages"]["passage_text"])
|
| 33 |
-
processed = preprocessor.preprocess_passages(passages)
|
| 34 |
-
texts = [p["text"] for p in processed]
|
| 35 |
-
print("texts\n")
|
| 36 |
-
|
| 37 |
-
logger.info(f"Processed {len(texts)} passages")
|
| 38 |
-
|
| 39 |
-
# 3. 生产embedding
|
| 40 |
-
embedder = Embedder(device="cuda")
|
| 41 |
-
embeddings = embedder.encode(texts)
|
| 42 |
-
print(f"Embedding shape: {getattr(embeddings, 'shape', None)}")
|
| 43 |
-
print(f"Texts count: {len(texts)}")
|
| 44 |
-
if embeddings is None or not hasattr(embeddings, 'shape') or len(embeddings.shape) != 2 or embeddings.shape[0] == 0:
|
| 45 |
-
raise ValueError("Embeddings is empty or not a 2D array. Check input texts and embedding model.")
|
| 46 |
-
|
| 47 |
-
# 4. 建立FAISS索引
|
| 48 |
-
index = build_faiss_index(embeddings, texts, index_type="HNSW")
|
| 49 |
-
logger.info("FAISS index built successfully")
|
| 50 |
-
# 持久化index到./index文件夹
|
| 51 |
-
index.save("../index/msmarco_hnsw")
|
| 52 |
-
logger.info("FAISS index saved to ./index/msmarco_hnsw")
|
| 53 |
-
return index
|
| 54 |
-
|
| 55 |
-
if __name__ == "__main__":
|
| 56 |
-
run_pipeline("train")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
generator/__init__.py
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
from .vllm_server import VLLMServer
|
| 2 |
-
from .safe_generate import SafeGenerator
|
| 3 |
-
from .prompt_templates import PromptTemplates
|
| 4 |
-
|
| 5 |
-
__all__ = ['VLLMServer', 'SafeGenerator', 'PromptTemplates']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
generator/prompt_templates.py
DELETED
|
@@ -1,113 +0,0 @@
|
|
| 1 |
-
from typing import List, Dict, Any
|
| 2 |
-
from dataclasses import dataclass
|
| 3 |
-
|
| 4 |
-
@dataclass
|
| 5 |
-
class PromptTemplate:
|
| 6 |
-
name: str
|
| 7 |
-
template: str
|
| 8 |
-
system_prompt: str = ""
|
| 9 |
-
|
| 10 |
-
class PromptTemplates:
|
| 11 |
-
def __init__(self):
|
| 12 |
-
self.templates = {
|
| 13 |
-
'rag': PromptTemplate(
|
| 14 |
-
name='rag',
|
| 15 |
-
system_prompt="You are a helpful assistant that answers questions based on provided context. Always cite your sources when possible.",
|
| 16 |
-
template="""Context:
|
| 17 |
-
{context}
|
| 18 |
-
|
| 19 |
-
Question: {question}
|
| 20 |
-
|
| 21 |
-
Answer:"""
|
| 22 |
-
),
|
| 23 |
-
|
| 24 |
-
'rag_with_citations': PromptTemplate(
|
| 25 |
-
name='rag_with_citations',
|
| 26 |
-
system_prompt="You are a helpful assistant that answers questions based on provided context. Always provide citations in the format [1], [2], etc.",
|
| 27 |
-
template="""Context:
|
| 28 |
-
{context}
|
| 29 |
-
|
| 30 |
-
Question: {question}
|
| 31 |
-
|
| 32 |
-
Answer (with citations):"""
|
| 33 |
-
),
|
| 34 |
-
|
| 35 |
-
'rag_safe': PromptTemplate(
|
| 36 |
-
name='rag_safe',
|
| 37 |
-
system_prompt="You are a helpful assistant that answers questions based on provided context. If you're uncertain, say so. Always cite your sources.",
|
| 38 |
-
template="""Context:
|
| 39 |
-
{context}
|
| 40 |
-
|
| 41 |
-
Question: {question}
|
| 42 |
-
|
| 43 |
-
Instructions:
|
| 44 |
-
- Answer based on the provided context
|
| 45 |
-
- If uncertain, express your uncertainty
|
| 46 |
-
- Always provide citations
|
| 47 |
-
- If the context doesn't contain enough information, say so
|
| 48 |
-
|
| 49 |
-
Answer:"""
|
| 50 |
-
),
|
| 51 |
-
|
| 52 |
-
'rag_uncertain': PromptTemplate(
|
| 53 |
-
name='rag_uncertain',
|
| 54 |
-
system_prompt="You are a helpful assistant. Express uncertainty when appropriate and always cite sources.",
|
| 55 |
-
template="""Context:
|
| 56 |
-
{context}
|
| 57 |
-
|
| 58 |
-
Question: {question}
|
| 59 |
-
|
| 60 |
-
Answer (express uncertainty if appropriate):"""
|
| 61 |
-
)
|
| 62 |
-
}
|
| 63 |
-
|
| 64 |
-
def get_template(self, name: str) -> PromptTemplate:
|
| 65 |
-
"""Get a prompt template by name"""
|
| 66 |
-
if name not in self.templates:
|
| 67 |
-
raise ValueError(f"Unknown template: {name}")
|
| 68 |
-
return self.templates[name]
|
| 69 |
-
|
| 70 |
-
def format_prompt(self, template_name: str, **kwargs) -> str:
|
| 71 |
-
"""Format a prompt using a template"""
|
| 72 |
-
template = self.get_template(template_name)
|
| 73 |
-
|
| 74 |
-
# Format the main template
|
| 75 |
-
formatted = template.template.format(**kwargs)
|
| 76 |
-
|
| 77 |
-
# Add system prompt if available
|
| 78 |
-
if template.system_prompt:
|
| 79 |
-
formatted = f"{template.system_prompt}\n\n{formatted}"
|
| 80 |
-
|
| 81 |
-
return formatted
|
| 82 |
-
|
| 83 |
-
def format_context(self, retrieved_passages: List[Dict[str, Any]],
|
| 84 |
-
max_length: int = 2000) -> str:
|
| 85 |
-
"""Format retrieved passages as context"""
|
| 86 |
-
context_parts = []
|
| 87 |
-
current_length = 0
|
| 88 |
-
|
| 89 |
-
for i, passage in enumerate(retrieved_passages):
|
| 90 |
-
text = passage.get('text', '')
|
| 91 |
-
if current_length + len(text) > max_length:
|
| 92 |
-
break
|
| 93 |
-
|
| 94 |
-
context_parts.append(f"[{i+1}] {text}")
|
| 95 |
-
current_length += len(text)
|
| 96 |
-
|
| 97 |
-
return "\n\n".join(context_parts)
|
| 98 |
-
|
| 99 |
-
def create_rag_prompt(self, question: str, retrieved_passages: List[Dict[str, Any]],
|
| 100 |
-
template_name: str = 'rag', max_context_length: int = 2000) -> str:
|
| 101 |
-
"""Create a RAG prompt"""
|
| 102 |
-
context = self.format_context(retrieved_passages, max_context_length)
|
| 103 |
-
return self.format_prompt(template_name, question=question, context=context)
|
| 104 |
-
|
| 105 |
-
def create_batch_prompts(self, questions: List[str],
|
| 106 |
-
retrieved_passages_list: List[List[Dict[str, Any]]],
|
| 107 |
-
template_name: str = 'rag') -> List[str]:
|
| 108 |
-
"""Create multiple RAG prompts"""
|
| 109 |
-
prompts = []
|
| 110 |
-
for question, passages in zip(questions, retrieved_passages_list):
|
| 111 |
-
prompt = self.create_rag_prompt(question, passages, template_name)
|
| 112 |
-
prompts.append(prompt)
|
| 113 |
-
return prompts
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
generator/safe_generate.py
DELETED
|
@@ -1,170 +0,0 @@
|
|
| 1 |
-
from typing import List, Dict, Any, Optional, Tuple
|
| 2 |
-
import logging
|
| 3 |
-
from .vllm_server import VLLMServer
|
| 4 |
-
from .prompt_templates import PromptTemplates
|
| 5 |
-
from ..calibration.features import RiskFeatureExtractor
|
| 6 |
-
|
| 7 |
-
logger = logging.getLogger(__name__)
|
| 8 |
-
|
| 9 |
-
class SafeGenerator:
|
| 10 |
-
def __init__(self, vllm_server: VLLMServer,
|
| 11 |
-
risk_extractor: RiskFeatureExtractor,
|
| 12 |
-
tau1: float = 0.3, tau2: float = 0.7):
|
| 13 |
-
self.vllm_server = vllm_server
|
| 14 |
-
self.risk_extractor = risk_extractor
|
| 15 |
-
self.prompt_templates = PromptTemplates()
|
| 16 |
-
self.tau1 = tau1 # Low risk threshold
|
| 17 |
-
self.tau2 = tau2 # High risk threshold
|
| 18 |
-
|
| 19 |
-
def generate_with_strategy(self, question: str,
|
| 20 |
-
retrieved_passages: List[Dict[str, Any]],
|
| 21 |
-
force_citation: bool = False) -> Dict[str, Any]:
|
| 22 |
-
"""Generate answer with adaptive strategy based on risk assessment"""
|
| 23 |
-
|
| 24 |
-
# Extract risk features
|
| 25 |
-
risk_features = self.risk_extractor.extract_features(
|
| 26 |
-
question, retrieved_passages
|
| 27 |
-
)
|
| 28 |
-
|
| 29 |
-
# Get risk score (placeholder - will be implemented in calibration module)
|
| 30 |
-
risk_score = self._estimate_risk_score(risk_features)
|
| 31 |
-
|
| 32 |
-
# Determine strategy based on risk score
|
| 33 |
-
if risk_score < self.tau1:
|
| 34 |
-
# Low risk: normal generation
|
| 35 |
-
strategy = "normal"
|
| 36 |
-
temperature = 0.7
|
| 37 |
-
template_name = "rag"
|
| 38 |
-
elif risk_score < self.tau2:
|
| 39 |
-
# Medium risk: conservative generation with citations
|
| 40 |
-
strategy = "conservative"
|
| 41 |
-
temperature = 0.5
|
| 42 |
-
template_name = "rag_with_citations"
|
| 43 |
-
force_citation = True
|
| 44 |
-
else:
|
| 45 |
-
# High risk: very conservative or refuse
|
| 46 |
-
strategy = "conservative_or_refuse"
|
| 47 |
-
temperature = 0.3
|
| 48 |
-
template_name = "rag_safe"
|
| 49 |
-
force_citation = True
|
| 50 |
-
|
| 51 |
-
# Generate prompt
|
| 52 |
-
prompt = self.prompt_templates.create_rag_prompt(
|
| 53 |
-
question, retrieved_passages, template_name
|
| 54 |
-
)
|
| 55 |
-
|
| 56 |
-
# Generate answer
|
| 57 |
-
try:
|
| 58 |
-
result = self.vllm_server.generate_single(
|
| 59 |
-
prompt,
|
| 60 |
-
max_tokens=512,
|
| 61 |
-
temperature=temperature
|
| 62 |
-
)
|
| 63 |
-
|
| 64 |
-
# Post-process for citations if needed
|
| 65 |
-
if force_citation:
|
| 66 |
-
result = self._add_citations(result, retrieved_passages)
|
| 67 |
-
|
| 68 |
-
return {
|
| 69 |
-
'answer': result,
|
| 70 |
-
'risk_score': risk_score,
|
| 71 |
-
'strategy': strategy,
|
| 72 |
-
'temperature': temperature,
|
| 73 |
-
'features': risk_features,
|
| 74 |
-
'citations': self._extract_citations(result, retrieved_passages)
|
| 75 |
-
}
|
| 76 |
-
|
| 77 |
-
except Exception as e:
|
| 78 |
-
logger.error(f"Generation failed: {e}")
|
| 79 |
-
return {
|
| 80 |
-
'answer': "I apologize, but I encountered an error while generating a response.",
|
| 81 |
-
'risk_score': 1.0,
|
| 82 |
-
'strategy': 'error',
|
| 83 |
-
'temperature': 0.0,
|
| 84 |
-
'features': risk_features,
|
| 85 |
-
'citations': []
|
| 86 |
-
}
|
| 87 |
-
|
| 88 |
-
def generate_batch(self, questions: List[str],
|
| 89 |
-
retrieved_passages_list: List[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
|
| 90 |
-
"""Generate answers for multiple questions"""
|
| 91 |
-
results = []
|
| 92 |
-
|
| 93 |
-
for question, passages in zip(questions, retrieved_passages_list):
|
| 94 |
-
result = self.generate_with_strategy(question, passages)
|
| 95 |
-
results.append(result)
|
| 96 |
-
|
| 97 |
-
return results
|
| 98 |
-
|
| 99 |
-
def _estimate_risk_score(self, features: Dict[str, Any]) -> float:
|
| 100 |
-
"""Estimate risk score from features (placeholder implementation)"""
|
| 101 |
-
# This is a simplified risk estimation
|
| 102 |
-
# In practice, this would use a trained calibration model
|
| 103 |
-
|
| 104 |
-
# Higher similarity scores = lower risk
|
| 105 |
-
avg_similarity = features.get('avg_similarity', 0.5)
|
| 106 |
-
|
| 107 |
-
# More diverse passages = lower risk
|
| 108 |
-
diversity = features.get('diversity', 0.5)
|
| 109 |
-
|
| 110 |
-
# More passages = lower risk (up to a point)
|
| 111 |
-
num_passages = min(features.get('num_passages', 1), 10)
|
| 112 |
-
passage_score = 1.0 - (num_passages / 10.0)
|
| 113 |
-
|
| 114 |
-
# Combine factors
|
| 115 |
-
risk_score = 1.0 - (avg_similarity * 0.4 + diversity * 0.3 + (1.0 - passage_score) * 0.3)
|
| 116 |
-
|
| 117 |
-
return max(0.0, min(1.0, risk_score))
|
| 118 |
-
|
| 119 |
-
def _add_citations(self, answer: str, passages: List[Dict[str, Any]]) -> str:
|
| 120 |
-
"""Add citations to answer if not present"""
|
| 121 |
-
if '[' in answer and ']' in answer:
|
| 122 |
-
return answer # Already has citations
|
| 123 |
-
|
| 124 |
-
# Simple citation addition (in practice, use more sophisticated methods)
|
| 125 |
-
cited_answer = answer
|
| 126 |
-
for i, passage in enumerate(passages[:3]): # Limit to first 3 passages
|
| 127 |
-
if any(word in answer.lower() for word in passage['text'].lower().split()[:5]):
|
| 128 |
-
cited_answer += f" [{i+1}]"
|
| 129 |
-
|
| 130 |
-
return cited_answer
|
| 131 |
-
|
| 132 |
-
def _extract_citations(self, answer: str, passages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 133 |
-
"""Extract citations from answer"""
|
| 134 |
-
citations = []
|
| 135 |
-
|
| 136 |
-
# Find citation markers like [1], [2], etc.
|
| 137 |
-
import re
|
| 138 |
-
citation_matches = re.findall(r'\[(\d+)\]', answer)
|
| 139 |
-
|
| 140 |
-
for match in citation_matches:
|
| 141 |
-
idx = int(match) - 1
|
| 142 |
-
if 0 <= idx < len(passages):
|
| 143 |
-
citations.append({
|
| 144 |
-
'id': idx,
|
| 145 |
-
'text': passages[idx]['text'],
|
| 146 |
-
'metadata': passages[idx].get('metadata', {})
|
| 147 |
-
})
|
| 148 |
-
|
| 149 |
-
return citations
|
| 150 |
-
|
| 151 |
-
def get_generation_stats(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 152 |
-
"""Get statistics from generation results"""
|
| 153 |
-
if not results:
|
| 154 |
-
return {}
|
| 155 |
-
|
| 156 |
-
risk_scores = [r['risk_score'] for r in results]
|
| 157 |
-
strategies = [r['strategy'] for r in results]
|
| 158 |
-
|
| 159 |
-
strategy_counts = {}
|
| 160 |
-
for strategy in strategies:
|
| 161 |
-
strategy_counts[strategy] = strategy_counts.get(strategy, 0) + 1
|
| 162 |
-
|
| 163 |
-
return {
|
| 164 |
-
'num_queries': len(results),
|
| 165 |
-
'avg_risk_score': sum(risk_scores) / len(risk_scores),
|
| 166 |
-
'min_risk_score': min(risk_scores),
|
| 167 |
-
'max_risk_score': max(risk_scores),
|
| 168 |
-
'strategy_distribution': strategy_counts,
|
| 169 |
-
'avg_citations_per_answer': sum(len(r.get('citations', [])) for r in results) / len(results)
|
| 170 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
generator/vllm_server.py
DELETED
|
@@ -1,102 +0,0 @@
|
|
| 1 |
-
from vllm import LLM, SamplingParams
|
| 2 |
-
from typing import List, Dict, Any, Optional
|
| 3 |
-
import logging
|
| 4 |
-
import asyncio
|
| 5 |
-
from concurrent.futures import ThreadPoolExecutor
|
| 6 |
-
|
| 7 |
-
logger = logging.getLogger(__name__)
|
| 8 |
-
|
| 9 |
-
class VLLMServer:
|
| 10 |
-
def __init__(self, model_name: str = "openai/gpt-oss-20b",
|
| 11 |
-
tensor_parallel_size: int = 1, gpu_memory_utilization: float = 0.9):
|
| 12 |
-
self.model_name = model_name
|
| 13 |
-
self.tensor_parallel_size = tensor_parallel_size
|
| 14 |
-
self.gpu_memory_utilization = gpu_memory_utilization
|
| 15 |
-
self.llm = None
|
| 16 |
-
self.executor = ThreadPoolExecutor(max_workers=4)
|
| 17 |
-
|
| 18 |
-
def initialize(self):
|
| 19 |
-
"""Initialize the vLLM model"""
|
| 20 |
-
try:
|
| 21 |
-
self.llm = LLM(
|
| 22 |
-
model=self.model_name,
|
| 23 |
-
tensor_parallel_size=self.tensor_parallel_size,
|
| 24 |
-
gpu_memory_utilization=self.gpu_memory_utilization,
|
| 25 |
-
trust_remote_code=True
|
| 26 |
-
)
|
| 27 |
-
logger.info(f"Initialized vLLM with model: {self.model_name}")
|
| 28 |
-
except Exception as e:
|
| 29 |
-
logger.error(f"Failed to initialize vLLM: {e}")
|
| 30 |
-
raise
|
| 31 |
-
|
| 32 |
-
def generate(self, prompts: List[str],
|
| 33 |
-
max_tokens: int = 512,
|
| 34 |
-
temperature: float = 0.7,
|
| 35 |
-
top_p: float = 0.9,
|
| 36 |
-
stop: Optional[List[str]] = None) -> List[Dict[str, Any]]:
|
| 37 |
-
"""Generate text for prompts"""
|
| 38 |
-
if self.llm is None:
|
| 39 |
-
self.initialize()
|
| 40 |
-
|
| 41 |
-
sampling_params = SamplingParams(
|
| 42 |
-
max_tokens=max_tokens,
|
| 43 |
-
temperature=temperature,
|
| 44 |
-
top_p=top_p,
|
| 45 |
-
stop=stop
|
| 46 |
-
)
|
| 47 |
-
|
| 48 |
-
try:
|
| 49 |
-
outputs = self.llm.generate(prompts, sampling_params)
|
| 50 |
-
|
| 51 |
-
results = []
|
| 52 |
-
for output in outputs:
|
| 53 |
-
results.append({
|
| 54 |
-
'text': output.outputs[0].text,
|
| 55 |
-
'prompt': output.prompt,
|
| 56 |
-
'finish_reason': output.outputs[0].finish_reason,
|
| 57 |
-
'token_ids': output.outputs[0].token_ids,
|
| 58 |
-
'logprobs': getattr(output.outputs[0], 'logprobs', None)
|
| 59 |
-
})
|
| 60 |
-
|
| 61 |
-
return results
|
| 62 |
-
except Exception as e:
|
| 63 |
-
logger.error(f"Generation failed: {e}")
|
| 64 |
-
raise
|
| 65 |
-
|
| 66 |
-
def generate_single(self, prompt: str, **kwargs) -> str:
|
| 67 |
-
"""Generate text for a single prompt"""
|
| 68 |
-
results = self.generate([prompt], **kwargs)
|
| 69 |
-
return results[0]['text'] if results else ""
|
| 70 |
-
|
| 71 |
-
def generate_batch(self, prompts: List[str], batch_size: int = 8, **kwargs) -> List[str]:
|
| 72 |
-
"""Generate text for multiple prompts in batches"""
|
| 73 |
-
all_results = []
|
| 74 |
-
|
| 75 |
-
for i in range(0, len(prompts), batch_size):
|
| 76 |
-
batch_prompts = prompts[i:i + batch_size]
|
| 77 |
-
batch_results = self.generate(batch_prompts, **kwargs)
|
| 78 |
-
all_results.extend([r['text'] for r in batch_results])
|
| 79 |
-
|
| 80 |
-
return all_results
|
| 81 |
-
|
| 82 |
-
async def generate_async(self, prompts: List[str], **kwargs) -> List[Dict[str, Any]]:
|
| 83 |
-
"""Async generation"""
|
| 84 |
-
loop = asyncio.get_event_loop()
|
| 85 |
-
return await loop.run_in_executor(self.executor, self.generate, prompts, **kwargs)
|
| 86 |
-
|
| 87 |
-
def get_model_info(self) -> Dict[str, Any]:
|
| 88 |
-
"""Get model information"""
|
| 89 |
-
if self.llm is None:
|
| 90 |
-
return {}
|
| 91 |
-
|
| 92 |
-
return {
|
| 93 |
-
'model_name': self.model_name,
|
| 94 |
-
'tensor_parallel_size': self.tensor_parallel_size,
|
| 95 |
-
'gpu_memory_utilization': self.gpu_memory_utilization,
|
| 96 |
-
'is_initialized': self.llm is not None
|
| 97 |
-
}
|
| 98 |
-
|
| 99 |
-
def cleanup(self):
|
| 100 |
-
"""Cleanup resources"""
|
| 101 |
-
if self.executor:
|
| 102 |
-
self.executor.shutdown(wait=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
real_embedding_test.py
DELETED
|
@@ -1,269 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
# -*- coding: utf-8 -*-
|
| 3 |
-
"""
|
| 4 |
-
SafeRAG Real Embedding Test
|
| 5 |
-
Load data -> Generate real embeddings using sentence-transformers -> Build index -> Retrieve
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
import sys
|
| 9 |
-
import os
|
| 10 |
-
import time
|
| 11 |
-
import numpy as np
|
| 12 |
-
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 13 |
-
|
| 14 |
-
def test_real_embedding_pipeline():
|
| 15 |
-
"""Test the complete pipeline with real embeddings"""
|
| 16 |
-
print("SafeRAG Real Embedding Pipeline Test")
|
| 17 |
-
print("=" * 50)
|
| 18 |
-
|
| 19 |
-
try:
|
| 20 |
-
# Step 1: Load data
|
| 21 |
-
print("\n1. Loading data...")
|
| 22 |
-
from data_processing import DataLoader, Preprocessor
|
| 23 |
-
|
| 24 |
-
loader = DataLoader()
|
| 25 |
-
preprocessor = Preprocessor()
|
| 26 |
-
|
| 27 |
-
# Load knowledge base
|
| 28 |
-
kb_passages = loader.get_knowledge_base()
|
| 29 |
-
print(f" ✓ Loaded {len(kb_passages)} knowledge base passages")
|
| 30 |
-
|
| 31 |
-
# Show sample passages
|
| 32 |
-
for i, passage in enumerate(kb_passages):
|
| 33 |
-
print(f" [{i+1}] {passage}")
|
| 34 |
-
|
| 35 |
-
# Preprocess passages
|
| 36 |
-
processed_passages = preprocessor.preprocess_passages(kb_passages)
|
| 37 |
-
print(f" ✓ Preprocessed {len(processed_passages)} passages")
|
| 38 |
-
|
| 39 |
-
# Step 2: Generate real embeddings
|
| 40 |
-
print("\n2. Generating real embeddings with sentence-transformers...")
|
| 41 |
-
from retriever import Embedder
|
| 42 |
-
|
| 43 |
-
# Use a smaller model for faster testing
|
| 44 |
-
embedder = Embedder(model_name="all-MiniLM-L6-v2", device="cpu")
|
| 45 |
-
print(f" ✓ Loaded embedding model: {embedder.model_name}")
|
| 46 |
-
print(f" ✓ Embedding dimension: {embedder.get_dimension()}")
|
| 47 |
-
|
| 48 |
-
# Extract text from processed passages
|
| 49 |
-
passage_texts = [p['text'] for p in processed_passages]
|
| 50 |
-
|
| 51 |
-
# Generate embeddings
|
| 52 |
-
start_time = time.time()
|
| 53 |
-
embeddings = embedder.encode_passages(passage_texts)
|
| 54 |
-
embedding_time = time.time() - start_time
|
| 55 |
-
|
| 56 |
-
print(f" ✓ Generated {embeddings.shape[0]} embeddings in {embedding_time:.3f}s")
|
| 57 |
-
print(f" ✓ Embedding shape: {embeddings.shape}")
|
| 58 |
-
print(f" ✓ Embedding type: {type(embeddings)}")
|
| 59 |
-
|
| 60 |
-
# Show embedding statistics
|
| 61 |
-
print(f" ✓ Embedding stats:")
|
| 62 |
-
print(f" - Mean: {np.mean(embeddings):.4f}")
|
| 63 |
-
print(f" - Std: {np.std(embeddings):.4f}")
|
| 64 |
-
print(f" - Min: {np.min(embeddings):.4f}")
|
| 65 |
-
print(f" - Max: {np.max(embeddings):.4f}")
|
| 66 |
-
|
| 67 |
-
# Step 3: Build FAISS index
|
| 68 |
-
print("\n3. Building FAISS index...")
|
| 69 |
-
from retriever import FAISSIndex
|
| 70 |
-
|
| 71 |
-
index = FAISSIndex(embedder.get_dimension())
|
| 72 |
-
start_time = time.time()
|
| 73 |
-
index.build_index(embeddings, passage_texts)
|
| 74 |
-
build_time = time.time() - start_time
|
| 75 |
-
|
| 76 |
-
print(f" ✓ Built FAISS index in {build_time:.3f}s")
|
| 77 |
-
print(f" ✓ Index contains {index.index.ntotal} vectors")
|
| 78 |
-
|
| 79 |
-
# Step 4: Test retrieval
|
| 80 |
-
print("\n4. Testing retrieval...")
|
| 81 |
-
from retriever import Retriever
|
| 82 |
-
|
| 83 |
-
retriever = Retriever(embedder, index, None) # No reranker for simplicity
|
| 84 |
-
|
| 85 |
-
test_queries = [
|
| 86 |
-
"What is machine learning?",
|
| 87 |
-
"Tell me about the capital of France",
|
| 88 |
-
"How does Python work?",
|
| 89 |
-
"What is artificial intelligence?"
|
| 90 |
-
]
|
| 91 |
-
|
| 92 |
-
for query in test_queries:
|
| 93 |
-
print(f"\n Query: '{query}'")
|
| 94 |
-
start_time = time.time()
|
| 95 |
-
results = retriever.retrieve_single(query, k=3)
|
| 96 |
-
retrieval_time = time.time() - start_time
|
| 97 |
-
|
| 98 |
-
print(f" ✓ Retrieved {len(results)} passages in {retrieval_time:.3f}s")
|
| 99 |
-
for i, result in enumerate(results):
|
| 100 |
-
print(f" [{i+1}] Score: {result['score']:.4f}")
|
| 101 |
-
print(f" Text: {result['text'][:100]}...")
|
| 102 |
-
|
| 103 |
-
# Step 5: Test similarity calculation
|
| 104 |
-
print("\n5. Testing similarity calculation...")
|
| 105 |
-
|
| 106 |
-
# Test query-passage similarity
|
| 107 |
-
query = "What is machine learning?"
|
| 108 |
-
query_embedding = embedder.encode_queries([query])[0]
|
| 109 |
-
|
| 110 |
-
print(f" Query: '{query}'")
|
| 111 |
-
print(f" Query embedding shape: {query_embedding.shape}")
|
| 112 |
-
|
| 113 |
-
# Calculate similarities with all passages
|
| 114 |
-
similarities = []
|
| 115 |
-
for i, passage_embedding in enumerate(embeddings):
|
| 116 |
-
# Cosine similarity
|
| 117 |
-
similarity = np.dot(query_embedding, passage_embedding) / (
|
| 118 |
-
np.linalg.norm(query_embedding) * np.linalg.norm(passage_embedding)
|
| 119 |
-
)
|
| 120 |
-
similarities.append((i, similarity, passage_texts[i]))
|
| 121 |
-
|
| 122 |
-
# Sort by similarity
|
| 123 |
-
similarities.sort(key=lambda x: x[1], reverse=True)
|
| 124 |
-
|
| 125 |
-
print(f" ✓ Calculated similarities with {len(similarities)} passages")
|
| 126 |
-
print(f" Top 3 most similar passages:")
|
| 127 |
-
for i, (idx, sim, text) in enumerate(similarities[:3]):
|
| 128 |
-
print(f" [{i+1}] Similarity: {sim:.4f}")
|
| 129 |
-
print(f" Text: {text[:80]}...")
|
| 130 |
-
|
| 131 |
-
# Step 6: Test generation
|
| 132 |
-
print("\n6. Testing generation...")
|
| 133 |
-
from generator import SafeGenerator, PromptTemplates
|
| 134 |
-
|
| 135 |
-
templates = PromptTemplates()
|
| 136 |
-
generator = SafeGenerator(None, None, 0.3, 0.7) # Simplified version
|
| 137 |
-
|
| 138 |
-
test_query = "What is machine learning?"
|
| 139 |
-
retrieved_passages = retriever.retrieve_single(test_query, k=3)
|
| 140 |
-
|
| 141 |
-
print(f" Query: '{test_query}'")
|
| 142 |
-
print(f" Retrieved {len(retrieved_passages)} passages")
|
| 143 |
-
|
| 144 |
-
# Generate answer
|
| 145 |
-
start_time = time.time()
|
| 146 |
-
result = generator.generate_with_strategy(test_query, retrieved_passages)
|
| 147 |
-
generation_time = time.time() - start_time
|
| 148 |
-
|
| 149 |
-
print(f" ✓ Generated answer in {generation_time:.3f}s")
|
| 150 |
-
print(f" Answer: {result['answer'][:200]}...")
|
| 151 |
-
print(f" Risk Score: {result['risk_score']:.3f}")
|
| 152 |
-
print(f" Strategy: {result['strategy']}")
|
| 153 |
-
|
| 154 |
-
print("\n" + "=" * 50)
|
| 155 |
-
print("🎉 Real embedding pipeline test completed successfully!")
|
| 156 |
-
print("\nPipeline Summary:")
|
| 157 |
-
print(f"- Data Loading: {len(kb_passages)} passages")
|
| 158 |
-
print(f"- Real Embedding Generation: {embeddings.shape[0]} vectors ({embeddings.shape[1]}D)")
|
| 159 |
-
print(f"- Index Building: {index.index.ntotal} indexed vectors")
|
| 160 |
-
print(f"- Retrieval: {len(test_queries)} test queries")
|
| 161 |
-
print(f"- Similarity Calculation: Cosine similarity with all passages")
|
| 162 |
-
print(f"- Generation: Risk-aware answer generation")
|
| 163 |
-
|
| 164 |
-
return True
|
| 165 |
-
|
| 166 |
-
except Exception as e:
|
| 167 |
-
print(f"\n❌ Pipeline test failed: {e}")
|
| 168 |
-
import traceback
|
| 169 |
-
traceback.print_exc()
|
| 170 |
-
return False
|
| 171 |
-
|
| 172 |
-
def test_embedding_quality():
|
| 173 |
-
"""Test embedding quality and properties"""
|
| 174 |
-
print("\n" + "=" * 50)
|
| 175 |
-
print("Testing Embedding Quality")
|
| 176 |
-
print("=" * 50)
|
| 177 |
-
|
| 178 |
-
try:
|
| 179 |
-
from retriever import Embedder
|
| 180 |
-
|
| 181 |
-
# Initialize embedder
|
| 182 |
-
embedder = Embedder(model_name="all-MiniLM-L6-v2", device="cpu")
|
| 183 |
-
|
| 184 |
-
# Test texts
|
| 185 |
-
test_texts = [
|
| 186 |
-
"Machine learning is a subset of artificial intelligence",
|
| 187 |
-
"The capital of France is Paris",
|
| 188 |
-
"Python is a programming language",
|
| 189 |
-
"Machine learning algorithms learn from data", # Similar to first
|
| 190 |
-
"Paris is the capital city of France", # Similar to second
|
| 191 |
-
]
|
| 192 |
-
|
| 193 |
-
print("1. Generating embeddings for test texts...")
|
| 194 |
-
embeddings = embedder.encode(test_texts)
|
| 195 |
-
print(f" ✓ Generated {embeddings.shape[0]} embeddings")
|
| 196 |
-
|
| 197 |
-
print("\n2. Testing similarity between related texts...")
|
| 198 |
-
|
| 199 |
-
# Test similarity between related texts
|
| 200 |
-
pairs = [
|
| 201 |
-
(0, 3, "Machine learning texts"),
|
| 202 |
-
(1, 4, "France/Paris texts"),
|
| 203 |
-
]
|
| 204 |
-
|
| 205 |
-
for i, j, description in pairs:
|
| 206 |
-
sim = np.dot(embeddings[i], embeddings[j]) / (
|
| 207 |
-
np.linalg.norm(embeddings[i]) * np.linalg.norm(embeddings[j])
|
| 208 |
-
)
|
| 209 |
-
print(f" {description}: {sim:.4f}")
|
| 210 |
-
print(f" Text 1: {test_texts[i]}")
|
| 211 |
-
print(f" Text 2: {test_texts[j]}")
|
| 212 |
-
|
| 213 |
-
print("\n3. Testing embedding properties...")
|
| 214 |
-
|
| 215 |
-
# Check if embeddings are normalized
|
| 216 |
-
norms = [np.linalg.norm(emb) for emb in embeddings]
|
| 217 |
-
print(f" ✓ Embedding norms: {[f'{n:.4f}' for n in norms]}")
|
| 218 |
-
|
| 219 |
-
# Check embedding statistics
|
| 220 |
-
all_embeddings = embeddings.flatten()
|
| 221 |
-
print(f" ✓ All embedding values:")
|
| 222 |
-
print(f" - Mean: {np.mean(all_embeddings):.4f}")
|
| 223 |
-
print(f" - Std: {np.std(all_embeddings):.4f}")
|
| 224 |
-
print(f" - Min: {np.min(all_embeddings):.4f}")
|
| 225 |
-
print(f" - Max: {np.max(all_embeddings):.4f}")
|
| 226 |
-
|
| 227 |
-
print("\n✅ Embedding quality test completed!")
|
| 228 |
-
return True
|
| 229 |
-
|
| 230 |
-
except Exception as e:
|
| 231 |
-
print(f"\n❌ Embedding quality test failed: {e}")
|
| 232 |
-
import traceback
|
| 233 |
-
traceback.print_exc()
|
| 234 |
-
return False
|
| 235 |
-
|
| 236 |
-
def main():
|
| 237 |
-
"""Run all tests"""
|
| 238 |
-
print("SafeRAG Real Embedding Test Suite")
|
| 239 |
-
print("=" * 60)
|
| 240 |
-
|
| 241 |
-
success = True
|
| 242 |
-
|
| 243 |
-
# Test embedding quality
|
| 244 |
-
if not test_embedding_quality():
|
| 245 |
-
success = False
|
| 246 |
-
|
| 247 |
-
# Test real embedding pipeline
|
| 248 |
-
if not test_real_embedding_pipeline():
|
| 249 |
-
success = False
|
| 250 |
-
|
| 251 |
-
print("\n" + "=" * 60)
|
| 252 |
-
if success:
|
| 253 |
-
print("🎉 All real embedding tests passed!")
|
| 254 |
-
print("\nThe system can now:")
|
| 255 |
-
print("1. ✅ Load data from knowledge base")
|
| 256 |
-
print("2. ✅ Generate real embeddings using sentence-transformers")
|
| 257 |
-
print("3. ✅ Build FAISS index with real embeddings")
|
| 258 |
-
print("4. ✅ Retrieve relevant passages using real similarity")
|
| 259 |
-
print("5. ✅ Calculate cosine similarity between queries and passages")
|
| 260 |
-
print("6. ✅ Generate answers based on retrieved passages")
|
| 261 |
-
print("7. ✅ Assess embedding quality and properties")
|
| 262 |
-
else:
|
| 263 |
-
print("❌ Some tests failed. Please check the errors above.")
|
| 264 |
-
|
| 265 |
-
return success
|
| 266 |
-
|
| 267 |
-
if __name__ == "__main__":
|
| 268 |
-
success = main()
|
| 269 |
-
sys.exit(0 if success else 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
DELETED
|
@@ -1,19 +0,0 @@
|
|
| 1 |
-
torch>=2.0.0
|
| 2 |
-
transformers>=4.35.0
|
| 3 |
-
datasets>=2.14.0
|
| 4 |
-
vllm>=0.2.0
|
| 5 |
-
faiss-cpu>=1.7.4
|
| 6 |
-
sentence-transformers>=2.2.2
|
| 7 |
-
scikit-learn>=1.3.0
|
| 8 |
-
numpy>=1.24.0
|
| 9 |
-
pandas>=2.0.0
|
| 10 |
-
tqdm>=4.65.0
|
| 11 |
-
gradio>=4.0.0
|
| 12 |
-
accelerate>=0.24.0
|
| 13 |
-
evaluate>=0.4.0
|
| 14 |
-
rouge-score>=0.1.2
|
| 15 |
-
nltk>=3.8.0
|
| 16 |
-
spacy>=3.7.0
|
| 17 |
-
matplotlib>=3.7.0
|
| 18 |
-
seaborn>=0.12.0
|
| 19 |
-
wandb>=0.15.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
retriever/__init__.py
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
from .embedder import Embedder
|
| 2 |
-
from .faiss_index import FAISSIndex
|
| 3 |
-
from .retriever import Retriever
|
| 4 |
-
from .reranker import Reranker
|
| 5 |
-
|
| 6 |
-
__all__ = ['Embedder', 'FAISSIndex', 'Retriever', 'Reranker']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
retriever/embedder.py
DELETED
|
@@ -1,49 +0,0 @@
|
|
| 1 |
-
from sentence_transformers import SentenceTransformer
|
| 2 |
-
from typing import List, Union
|
| 3 |
-
import numpy as np
|
| 4 |
-
import logging
|
| 5 |
-
|
| 6 |
-
logger = logging.getLogger(__name__)
|
| 7 |
-
|
| 8 |
-
class Embedder:
|
| 9 |
-
def __init__(self, model_name: str = "BAAI/bge-large-en-v1.5", device: str = "cuda"):
|
| 10 |
-
self.model_name = model_name
|
| 11 |
-
self.device = device
|
| 12 |
-
self.model = SentenceTransformer(model_name, device=device)
|
| 13 |
-
logger.info(f"Loaded embedding model: {model_name}")
|
| 14 |
-
|
| 15 |
-
def encode(self, texts: Union[str, List[str]], batch_size: int = 16) -> np.ndarray:
|
| 16 |
-
"""Encode texts to embeddings"""
|
| 17 |
-
if isinstance(texts, str):
|
| 18 |
-
texts = [texts]
|
| 19 |
-
|
| 20 |
-
embeddings = self.model.encode(
|
| 21 |
-
texts,
|
| 22 |
-
batch_size=batch_size,
|
| 23 |
-
convert_to_numpy=True,
|
| 24 |
-
show_progress_bar=len(texts) > 100
|
| 25 |
-
)
|
| 26 |
-
|
| 27 |
-
return embeddings
|
| 28 |
-
|
| 29 |
-
def encode_queries(self, queries: List[str], batch_size: int = 16) -> np.ndarray:
|
| 30 |
-
"""Encode queries with query prefix"""
|
| 31 |
-
if not queries:
|
| 32 |
-
return np.array([])
|
| 33 |
-
|
| 34 |
-
# Add query prefix for BGE models
|
| 35 |
-
prefixed_queries = [f"Represent this sentence for searching relevant passages: {q}" for q in queries]
|
| 36 |
-
return self.encode(prefixed_queries, batch_size)
|
| 37 |
-
|
| 38 |
-
def encode_passages(self, passages: List[str], batch_size: int = 16) -> np.ndarray:
|
| 39 |
-
"""Encode passages with passage prefix"""
|
| 40 |
-
if not passages:
|
| 41 |
-
return np.array([])
|
| 42 |
-
|
| 43 |
-
# Add passage prefix for BGE models
|
| 44 |
-
prefixed_passages = [f"Represent this sentence for searching relevant passages: {p}" for p in passages]
|
| 45 |
-
return self.encode(prefixed_passages, batch_size)
|
| 46 |
-
|
| 47 |
-
def get_dimension(self) -> int:
|
| 48 |
-
"""Get embedding dimension"""
|
| 49 |
-
return self.model.get_sentence_embedding_dimension()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
retriever/faiss_index.py
DELETED
|
@@ -1,131 +0,0 @@
|
|
| 1 |
-
# 工厂函数,供pipeline调用
|
| 2 |
-
def build_faiss_index(embeddings, texts, metadata=None, index_type="HNSW"):
|
| 3 |
-
if embeddings is None or not hasattr(embeddings, 'shape') or len(embeddings.shape) != 2 or embeddings.shape[0] == 0:
|
| 4 |
-
raise ValueError(f"Embeddings is empty or not a 2D array. Got shape: {getattr(embeddings, 'shape', None)}")
|
| 5 |
-
dimension = embeddings.shape[1]
|
| 6 |
-
index = FAISSIndex(dimension, index_type=index_type)
|
| 7 |
-
index.build_index(embeddings, texts, metadata)
|
| 8 |
-
return index
|
| 9 |
-
import faiss
|
| 10 |
-
import numpy as np
|
| 11 |
-
import pickle
|
| 12 |
-
import os
|
| 13 |
-
from typing import List, Dict, Any, Tuple
|
| 14 |
-
import logging
|
| 15 |
-
|
| 16 |
-
logger = logging.getLogger(__name__)
|
| 17 |
-
|
| 18 |
-
class FAISSIndex:
|
| 19 |
-
def __init__(self, dimension: int, index_type: str = "HNSW"):
|
| 20 |
-
self.dimension = dimension
|
| 21 |
-
self.index_type = index_type
|
| 22 |
-
self.index = None
|
| 23 |
-
self.id_to_text = {}
|
| 24 |
-
self.id_to_metadata = {}
|
| 25 |
-
self.next_id = 0
|
| 26 |
-
|
| 27 |
-
def build_index(self, embeddings: np.ndarray, texts: List[str],
|
| 28 |
-
metadata: List[Dict[str, Any]] = None) -> None:
|
| 29 |
-
"""Build FAISS index from embeddings"""
|
| 30 |
-
if embeddings.shape[1] != self.dimension:
|
| 31 |
-
raise ValueError(f"Embedding dimension {embeddings.shape[1]} != {self.dimension}")
|
| 32 |
-
# Normalize embeddings for cosine similarity
|
| 33 |
-
faiss.normalize_L2(embeddings)
|
| 34 |
-
if self.index_type == "HNSW":
|
| 35 |
-
# HNSW index for fast approximate search
|
| 36 |
-
self.index = faiss.IndexHNSWFlat(self.dimension, 32) # 32 is default M
|
| 37 |
-
self.index.hnsw.efConstruction = 200
|
| 38 |
-
self.index.add(embeddings)
|
| 39 |
-
elif self.index_type == "IVF":
|
| 40 |
-
nlist = min(4096, len(embeddings) // 100)
|
| 41 |
-
quantizer = faiss.IndexFlatIP(self.dimension)
|
| 42 |
-
self.index = faiss.IndexIVFFlat(quantizer, self.dimension, nlist)
|
| 43 |
-
self.index.train(embeddings)
|
| 44 |
-
self.index.add(embeddings)
|
| 45 |
-
else:
|
| 46 |
-
self.index = faiss.IndexFlatIP(self.dimension)
|
| 47 |
-
self.index.add(embeddings)
|
| 48 |
-
# Store text and metadata
|
| 49 |
-
for i, text in enumerate(texts):
|
| 50 |
-
self.id_to_text[i] = text
|
| 51 |
-
if metadata and i < len(metadata):
|
| 52 |
-
self.id_to_metadata[i] = metadata[i]
|
| 53 |
-
logger.info(f"Built FAISS {self.index_type} index with {len(embeddings)} vectors")
|
| 54 |
-
|
| 55 |
-
def search(self, query_embeddings: np.ndarray, k: int = 10) -> Tuple[np.ndarray, np.ndarray]:
|
| 56 |
-
"""Search for similar vectors"""
|
| 57 |
-
if self.index is None:
|
| 58 |
-
raise ValueError("Index not built yet")
|
| 59 |
-
|
| 60 |
-
# Normalize query embeddings
|
| 61 |
-
faiss.normalize_L2(query_embeddings)
|
| 62 |
-
|
| 63 |
-
# Search
|
| 64 |
-
scores, indices = self.index.search(query_embeddings, k)
|
| 65 |
-
|
| 66 |
-
return scores, indices
|
| 67 |
-
|
| 68 |
-
def get_texts(self, indices: np.ndarray) -> List[str]:
|
| 69 |
-
"""Get texts by indices"""
|
| 70 |
-
texts = []
|
| 71 |
-
for idx in indices.flatten():
|
| 72 |
-
if idx in self.id_to_text:
|
| 73 |
-
texts.append(self.id_to_text[idx])
|
| 74 |
-
else:
|
| 75 |
-
texts.append("")
|
| 76 |
-
return texts
|
| 77 |
-
|
| 78 |
-
def get_metadata(self, indices: np.ndarray) -> List[Dict[str, Any]]:
|
| 79 |
-
"""Get metadata by indices"""
|
| 80 |
-
metadata = []
|
| 81 |
-
for idx in indices.flatten():
|
| 82 |
-
if idx in self.id_to_metadata:
|
| 83 |
-
metadata.append(self.id_to_metadata[idx])
|
| 84 |
-
else:
|
| 85 |
-
metadata.append({})
|
| 86 |
-
return metadata
|
| 87 |
-
|
| 88 |
-
def save(self, path: str) -> None:
|
| 89 |
-
"""Save index to disk"""
|
| 90 |
-
os.makedirs(os.path.dirname(path), exist_ok=True)
|
| 91 |
-
|
| 92 |
-
# Save FAISS index
|
| 93 |
-
faiss.write_index(self.index, f"{path}.faiss")
|
| 94 |
-
|
| 95 |
-
# Save metadata
|
| 96 |
-
with open(f"{path}.pkl", "wb") as f:
|
| 97 |
-
pickle.dump({
|
| 98 |
-
'id_to_text': self.id_to_text,
|
| 99 |
-
'id_to_metadata': self.id_to_metadata,
|
| 100 |
-
'dimension': self.dimension,
|
| 101 |
-
'index_type': self.index_type
|
| 102 |
-
}, f)
|
| 103 |
-
|
| 104 |
-
logger.info(f"Saved index to {path}")
|
| 105 |
-
|
| 106 |
-
def load(self, path: str) -> None:
|
| 107 |
-
"""Load index from disk"""
|
| 108 |
-
# Load FAISS index
|
| 109 |
-
self.index = faiss.read_index(f"{path}.faiss")
|
| 110 |
-
|
| 111 |
-
# Load metadata
|
| 112 |
-
with open(f"{path}.pkl", "rb") as f:
|
| 113 |
-
data = pickle.load(f)
|
| 114 |
-
self.id_to_text = data['id_to_text']
|
| 115 |
-
self.id_to_metadata = data['id_to_metadata']
|
| 116 |
-
self.dimension = data['dimension']
|
| 117 |
-
self.index_type = data['index_type']
|
| 118 |
-
|
| 119 |
-
logger.info(f"Loaded index from {path}")
|
| 120 |
-
|
| 121 |
-
def get_stats(self) -> Dict[str, Any]:
|
| 122 |
-
"""Get index statistics"""
|
| 123 |
-
if self.index is None:
|
| 124 |
-
return {}
|
| 125 |
-
|
| 126 |
-
return {
|
| 127 |
-
'num_vectors': self.index.ntotal,
|
| 128 |
-
'dimension': self.dimension,
|
| 129 |
-
'index_type': self.index_type,
|
| 130 |
-
'is_trained': self.index.is_trained if hasattr(self.index, 'is_trained') else True
|
| 131 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
retriever/reranker.py
DELETED
|
@@ -1,46 +0,0 @@
|
|
| 1 |
-
from sentence_transformers import CrossEncoder
|
| 2 |
-
from typing import List
|
| 3 |
-
import numpy as np
|
| 4 |
-
import logging
|
| 5 |
-
|
| 6 |
-
logger = logging.getLogger(__name__)
|
| 7 |
-
|
| 8 |
-
class Reranker:
|
| 9 |
-
def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2", device: str = "cuda"):
|
| 10 |
-
self.model_name = model_name
|
| 11 |
-
self.device = device
|
| 12 |
-
self.model = CrossEncoder(model_name, device=device)
|
| 13 |
-
logger.info(f"Loaded reranker model: {model_name}")
|
| 14 |
-
|
| 15 |
-
def rerank(self, query: str, passages: List[str], batch_size: int = 32) -> List[float]:
|
| 16 |
-
"""Rerank passages for a query"""
|
| 17 |
-
if not passages:
|
| 18 |
-
return []
|
| 19 |
-
|
| 20 |
-
# Create query-passage pairs
|
| 21 |
-
pairs = [(query, passage) for passage in passages]
|
| 22 |
-
|
| 23 |
-
# Get relevance scores
|
| 24 |
-
scores = self.model.predict(pairs, batch_size=batch_size)
|
| 25 |
-
|
| 26 |
-
return scores.tolist()
|
| 27 |
-
|
| 28 |
-
def rerank_batch(self, queries: List[str], passages_list: List[List[str]],
|
| 29 |
-
batch_size: int = 32) -> List[List[float]]:
|
| 30 |
-
"""Rerank passages for multiple queries"""
|
| 31 |
-
all_scores = []
|
| 32 |
-
|
| 33 |
-
for query, passages in zip(queries, passages_list):
|
| 34 |
-
scores = self.rerank(query, passages, batch_size)
|
| 35 |
-
all_scores.append(scores)
|
| 36 |
-
|
| 37 |
-
return all_scores
|
| 38 |
-
|
| 39 |
-
def get_top_k(self, query: str, passages: List[str], k: int = 5) -> List[tuple]:
|
| 40 |
-
"""Get top-k passages with scores"""
|
| 41 |
-
scores = self.rerank(query, passages)
|
| 42 |
-
|
| 43 |
-
# Sort by score
|
| 44 |
-
ranked = sorted(zip(passages, scores), key=lambda x: x[1], reverse=True)
|
| 45 |
-
|
| 46 |
-
return ranked[:k]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
retriever/retriever.py
DELETED
|
@@ -1,104 +0,0 @@
|
|
| 1 |
-
from typing import List, Dict, Any, Tuple
|
| 2 |
-
import numpy as np
|
| 3 |
-
from .embedder import Embedder
|
| 4 |
-
from .faiss_index import FAISSIndex
|
| 5 |
-
from .reranker import Reranker
|
| 6 |
-
import logging
|
| 7 |
-
|
| 8 |
-
logger = logging.getLogger(__name__)
|
| 9 |
-
|
| 10 |
-
class Retriever:
|
| 11 |
-
def __init__(self, embedder: Embedder, index: FAISSIndex, reranker: Reranker = None):
|
| 12 |
-
self.embedder = embedder
|
| 13 |
-
self.index = index
|
| 14 |
-
self.reranker = reranker
|
| 15 |
-
|
| 16 |
-
def retrieve(self, queries: List[str], k: int = 20,
|
| 17 |
-
rerank_k: int = 10) -> List[List[Dict[str, Any]]]:
|
| 18 |
-
"""Retrieve and rerank passages for queries"""
|
| 19 |
-
if not queries:
|
| 20 |
-
return []
|
| 21 |
-
|
| 22 |
-
# Encode queries
|
| 23 |
-
query_embeddings = self.embedder.encode_queries(queries)
|
| 24 |
-
|
| 25 |
-
# Search index
|
| 26 |
-
scores, indices = self.index.search(query_embeddings, k)
|
| 27 |
-
|
| 28 |
-
# Format results
|
| 29 |
-
results = []
|
| 30 |
-
for i, query in enumerate(queries):
|
| 31 |
-
query_results = []
|
| 32 |
-
for j, (score, idx) in enumerate(zip(scores[i], indices[i])):
|
| 33 |
-
if idx == -1: # Invalid index
|
| 34 |
-
continue
|
| 35 |
-
|
| 36 |
-
text = self.index.id_to_text.get(idx, "")
|
| 37 |
-
metadata = self.index.id_to_metadata.get(idx, {})
|
| 38 |
-
|
| 39 |
-
query_results.append({
|
| 40 |
-
'text': text,
|
| 41 |
-
'score': float(score),
|
| 42 |
-
'rank': j + 1,
|
| 43 |
-
'metadata': metadata,
|
| 44 |
-
'id': idx
|
| 45 |
-
})
|
| 46 |
-
|
| 47 |
-
results.append(query_results)
|
| 48 |
-
|
| 49 |
-
# Rerank if reranker is available
|
| 50 |
-
if self.reranker and rerank_k < k:
|
| 51 |
-
reranked_results = []
|
| 52 |
-
for i, query in enumerate(queries):
|
| 53 |
-
passages = [r['text'] for r in results[i][:k]]
|
| 54 |
-
rerank_scores = self.reranker.rerank(query, passages)
|
| 55 |
-
|
| 56 |
-
# Reorder results based on rerank scores
|
| 57 |
-
reranked = sorted(
|
| 58 |
-
zip(results[i][:k], rerank_scores),
|
| 59 |
-
key=lambda x: x[1],
|
| 60 |
-
reverse=True
|
| 61 |
-
)
|
| 62 |
-
|
| 63 |
-
reranked_results.append([
|
| 64 |
-
{**result, 'rerank_score': score, 'rank': j + 1}
|
| 65 |
-
for j, (result, score) in enumerate(reranked[:rerank_k])
|
| 66 |
-
])
|
| 67 |
-
|
| 68 |
-
results = reranked_results
|
| 69 |
-
|
| 70 |
-
return results
|
| 71 |
-
|
| 72 |
-
def retrieve_single(self, query: str, k: int = 10) -> List[Dict[str, Any]]:
|
| 73 |
-
"""Retrieve for a single query"""
|
| 74 |
-
results = self.retrieve([query], k)
|
| 75 |
-
return results[0] if results else []
|
| 76 |
-
|
| 77 |
-
def batch_retrieve(self, queries: List[str], batch_size: int = 32,
|
| 78 |
-
k: int = 10) -> List[List[Dict[str, Any]]]:
|
| 79 |
-
"""Retrieve for multiple queries in batches"""
|
| 80 |
-
all_results = []
|
| 81 |
-
|
| 82 |
-
for i in range(0, len(queries), batch_size):
|
| 83 |
-
batch_queries = queries[i:i + batch_size]
|
| 84 |
-
batch_results = self.retrieve(batch_queries, k)
|
| 85 |
-
all_results.extend(batch_results)
|
| 86 |
-
|
| 87 |
-
return all_results
|
| 88 |
-
|
| 89 |
-
def get_retrieval_stats(self, queries: List[str], k: int = 10) -> Dict[str, Any]:
|
| 90 |
-
"""Get retrieval statistics"""
|
| 91 |
-
results = self.retrieve(queries, k)
|
| 92 |
-
|
| 93 |
-
scores = []
|
| 94 |
-
for query_results in results:
|
| 95 |
-
scores.extend([r['score'] for r in query_results])
|
| 96 |
-
|
| 97 |
-
return {
|
| 98 |
-
'num_queries': len(queries),
|
| 99 |
-
'avg_scores': np.mean(scores) if scores else 0,
|
| 100 |
-
'std_scores': np.std(scores) if scores else 0,
|
| 101 |
-
'min_scores': np.min(scores) if scores else 0,
|
| 102 |
-
'max_scores': np.max(scores) if scores else 0,
|
| 103 |
-
'avg_results_per_query': np.mean([len(r) for r in results])
|
| 104 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
simple_e2e_test.py
DELETED
|
@@ -1,518 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
# -*- coding: utf-8 -*-
|
| 3 |
-
"""
|
| 4 |
-
SafeRAG Simple End-to-End Test
|
| 5 |
-
Complete workflow test without external dependencies
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
import sys
|
| 9 |
-
import os
|
| 10 |
-
import time
|
| 11 |
-
import random
|
| 12 |
-
import math
|
| 13 |
-
|
| 14 |
-
# Add project root to path
|
| 15 |
-
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 16 |
-
|
| 17 |
-
def test_basic_functionality():
|
| 18 |
-
"""Test basic Python functionality"""
|
| 19 |
-
print("Testing basic functionality...")
|
| 20 |
-
|
| 21 |
-
try:
|
| 22 |
-
# Test basic operations
|
| 23 |
-
assert 1 + 1 == 2, "Basic math failed"
|
| 24 |
-
assert "hello" + " " + "world" == "hello world", "String concatenation failed"
|
| 25 |
-
assert len([1, 2, 3]) == 3, "List length failed"
|
| 26 |
-
print("+ Basic Python operations work")
|
| 27 |
-
|
| 28 |
-
# Test random number generation
|
| 29 |
-
random.seed(42)
|
| 30 |
-
rand_num = random.random()
|
| 31 |
-
assert 0 <= rand_num <= 1, "Random number out of range"
|
| 32 |
-
print("+ Random number generation works")
|
| 33 |
-
|
| 34 |
-
return True
|
| 35 |
-
except Exception as e:
|
| 36 |
-
print("✗ Basic functionality test failed:", e)
|
| 37 |
-
return False
|
| 38 |
-
|
| 39 |
-
def test_text_processing():
|
| 40 |
-
"""Test text processing functionality"""
|
| 41 |
-
print("\nTesting text processing...")
|
| 42 |
-
|
| 43 |
-
try:
|
| 44 |
-
# Simple text cleaning
|
| 45 |
-
def clean_text(text):
|
| 46 |
-
if not text:
|
| 47 |
-
return ""
|
| 48 |
-
# Remove extra whitespace
|
| 49 |
-
import re
|
| 50 |
-
text = re.sub(r'\s+', ' ', text)
|
| 51 |
-
# Remove special characters but keep punctuation
|
| 52 |
-
text = re.sub(r'[^\w\s\.\,\!\?\;\:\-\(\)]', '', text)
|
| 53 |
-
return text.strip()
|
| 54 |
-
|
| 55 |
-
# Test text cleaning
|
| 56 |
-
test_text = " This is a test text!!! "
|
| 57 |
-
cleaned = clean_text(test_text)
|
| 58 |
-
expected = "This is a test text!!!"
|
| 59 |
-
assert cleaned == expected, "Text cleaning failed: got '{}', expected '{}'".format(cleaned, expected)
|
| 60 |
-
print("+ Text cleaning works")
|
| 61 |
-
|
| 62 |
-
# Test sentence extraction
|
| 63 |
-
def extract_sentences(text):
|
| 64 |
-
sentences = text.split('.')
|
| 65 |
-
return [clean_text(s) for s in sentences if s.strip()]
|
| 66 |
-
|
| 67 |
-
test_text = "First sentence. Second sentence. Third sentence."
|
| 68 |
-
sentences = extract_sentences(test_text)
|
| 69 |
-
assert len(sentences) == 3, "Sentence extraction failed: got {} sentences, expected 3".format(len(sentences))
|
| 70 |
-
print("+ Sentence extraction works")
|
| 71 |
-
|
| 72 |
-
return True
|
| 73 |
-
except Exception as e:
|
| 74 |
-
print("✗ Text processing test failed:", e)
|
| 75 |
-
return False
|
| 76 |
-
|
| 77 |
-
def test_simple_embeddings():
|
| 78 |
-
"""Test simple embedding simulation"""
|
| 79 |
-
print("\nTesting simple embeddings...")
|
| 80 |
-
|
| 81 |
-
try:
|
| 82 |
-
# Simple embedding simulation using random numbers
|
| 83 |
-
def create_simple_embeddings(texts, dim=10):
|
| 84 |
-
"""Create simple random embeddings for testing"""
|
| 85 |
-
random.seed(42) # For reproducibility
|
| 86 |
-
embeddings = []
|
| 87 |
-
for text in texts:
|
| 88 |
-
embedding = [random.random() for _ in range(dim)]
|
| 89 |
-
# Simple normalization
|
| 90 |
-
norm = math.sqrt(sum(x*x for x in embedding))
|
| 91 |
-
if norm > 0:
|
| 92 |
-
embedding = [x/norm for x in embedding]
|
| 93 |
-
embeddings.append(embedding)
|
| 94 |
-
return embeddings
|
| 95 |
-
|
| 96 |
-
# Test embedding creation
|
| 97 |
-
texts = ["This is a test", "Another test sentence"]
|
| 98 |
-
embeddings = create_simple_embeddings(texts)
|
| 99 |
-
assert len(embeddings) == 2, "Wrong number of embeddings"
|
| 100 |
-
assert len(embeddings[0]) == 10, "Wrong embedding dimension"
|
| 101 |
-
print("+ Simple embedding creation works")
|
| 102 |
-
|
| 103 |
-
# Test similarity calculation
|
| 104 |
-
def cosine_similarity(a, b):
|
| 105 |
-
dot_product = sum(x * y for x, y in zip(a, b))
|
| 106 |
-
norm_a = math.sqrt(sum(x*x for x in a))
|
| 107 |
-
norm_b = math.sqrt(sum(x*x for x in b))
|
| 108 |
-
if norm_a == 0 or norm_b == 0:
|
| 109 |
-
return 0
|
| 110 |
-
return dot_product / (norm_a * norm_b)
|
| 111 |
-
|
| 112 |
-
sim = cosine_similarity(embeddings[0], embeddings[1])
|
| 113 |
-
assert 0 <= sim <= 1, "Similarity score out of range: {}".format(sim)
|
| 114 |
-
print("+ Similarity calculation works")
|
| 115 |
-
|
| 116 |
-
return True
|
| 117 |
-
except Exception as e:
|
| 118 |
-
print("✗ Simple embeddings test failed:", e)
|
| 119 |
-
return False
|
| 120 |
-
|
| 121 |
-
def test_simple_retrieval():
|
| 122 |
-
"""Test simple retrieval functionality"""
|
| 123 |
-
print("\nTesting simple retrieval...")
|
| 124 |
-
|
| 125 |
-
try:
|
| 126 |
-
# Simple retrieval simulation
|
| 127 |
-
class SimpleRetriever:
|
| 128 |
-
def __init__(self, passages, embeddings):
|
| 129 |
-
self.passages = passages
|
| 130 |
-
self.embeddings = embeddings
|
| 131 |
-
|
| 132 |
-
def search(self, query_embedding, k=5):
|
| 133 |
-
# Calculate similarities
|
| 134 |
-
similarities = []
|
| 135 |
-
for embedding in self.embeddings:
|
| 136 |
-
sim = sum(x * y for x, y in zip(embedding, query_embedding))
|
| 137 |
-
similarities.append(sim)
|
| 138 |
-
|
| 139 |
-
# Get top-k indices
|
| 140 |
-
indexed_sims = [(i, sim) for i, sim in enumerate(similarities)]
|
| 141 |
-
indexed_sims.sort(key=lambda x: x[1], reverse=True)
|
| 142 |
-
top_indices = [i for i, _ in indexed_sims[:k]]
|
| 143 |
-
|
| 144 |
-
# Return results
|
| 145 |
-
results = []
|
| 146 |
-
for i, idx in enumerate(top_indices):
|
| 147 |
-
results.append({
|
| 148 |
-
'text': self.passages[idx],
|
| 149 |
-
'score': similarities[idx],
|
| 150 |
-
'rank': i + 1
|
| 151 |
-
})
|
| 152 |
-
return results
|
| 153 |
-
|
| 154 |
-
# Create test data
|
| 155 |
-
passages = [
|
| 156 |
-
"Machine learning is a subset of artificial intelligence.",
|
| 157 |
-
"Deep learning uses neural networks with multiple layers.",
|
| 158 |
-
"Natural language processing deals with text and speech.",
|
| 159 |
-
"Computer vision focuses on image and video analysis."
|
| 160 |
-
]
|
| 161 |
-
|
| 162 |
-
# Create simple embeddings
|
| 163 |
-
def create_simple_embeddings(texts, dim=10):
|
| 164 |
-
random.seed(42)
|
| 165 |
-
embeddings = []
|
| 166 |
-
for text in texts:
|
| 167 |
-
embedding = [random.random() for _ in range(dim)]
|
| 168 |
-
norm = math.sqrt(sum(x*x for x in embedding))
|
| 169 |
-
if norm > 0:
|
| 170 |
-
embedding = [x/norm for x in embedding]
|
| 171 |
-
embeddings.append(embedding)
|
| 172 |
-
return embeddings
|
| 173 |
-
|
| 174 |
-
embeddings = create_simple_embeddings(passages)
|
| 175 |
-
|
| 176 |
-
# Test retrieval
|
| 177 |
-
retriever = SimpleRetriever(passages, embeddings)
|
| 178 |
-
query_embedding = [random.random() for _ in range(10)]
|
| 179 |
-
norm = math.sqrt(sum(x*x for x in query_embedding))
|
| 180 |
-
if norm > 0:
|
| 181 |
-
query_embedding = [x/norm for x in query_embedding]
|
| 182 |
-
|
| 183 |
-
results = retriever.search(query_embedding, k=3)
|
| 184 |
-
assert len(results) == 3, "Retrieval returned wrong number of results: {}".format(len(results))
|
| 185 |
-
assert all('text' in r and 'score' in r for r in results), "Retrieval results missing fields"
|
| 186 |
-
print("+ Simple retrieval works")
|
| 187 |
-
|
| 188 |
-
return True
|
| 189 |
-
except Exception as e:
|
| 190 |
-
print("✗ Simple retrieval test failed:", e)
|
| 191 |
-
return False
|
| 192 |
-
|
| 193 |
-
def test_risk_calibration():
|
| 194 |
-
"""Test risk calibration functionality"""
|
| 195 |
-
print("\nTesting risk calibration...")
|
| 196 |
-
|
| 197 |
-
try:
|
| 198 |
-
# Simple risk feature extraction
|
| 199 |
-
def extract_risk_features(question, retrieved_passages):
|
| 200 |
-
features = {}
|
| 201 |
-
|
| 202 |
-
if not retrieved_passages:
|
| 203 |
-
return {'num_passages': 0, 'avg_similarity': 0.0, 'diversity': 0.0}
|
| 204 |
-
|
| 205 |
-
# Basic features
|
| 206 |
-
features['num_passages'] = len(retrieved_passages)
|
| 207 |
-
scores = [p['score'] for p in retrieved_passages]
|
| 208 |
-
features['avg_similarity'] = sum(scores) / len(scores)
|
| 209 |
-
features['max_similarity'] = max(scores)
|
| 210 |
-
features['min_similarity'] = min(scores)
|
| 211 |
-
|
| 212 |
-
# Simple diversity calculation
|
| 213 |
-
if len(scores) > 1:
|
| 214 |
-
mean_score = features['avg_similarity']
|
| 215 |
-
variance = sum((x - mean_score) ** 2 for x in scores) / len(scores)
|
| 216 |
-
features['diversity'] = 1.0 - math.sqrt(variance)
|
| 217 |
-
else:
|
| 218 |
-
features['diversity'] = 1.0
|
| 219 |
-
|
| 220 |
-
return features
|
| 221 |
-
|
| 222 |
-
# Simple risk prediction
|
| 223 |
-
def predict_risk(features):
|
| 224 |
-
# Simple heuristic for risk scoring
|
| 225 |
-
risk_score = 0.0
|
| 226 |
-
|
| 227 |
-
# Few passages = higher risk
|
| 228 |
-
if features['num_passages'] < 3:
|
| 229 |
-
risk_score += 0.3
|
| 230 |
-
|
| 231 |
-
# Low similarity = higher risk
|
| 232 |
-
if features['avg_similarity'] < 0.5:
|
| 233 |
-
risk_score += 0.2
|
| 234 |
-
|
| 235 |
-
# Low diversity = higher risk
|
| 236 |
-
if features['diversity'] < 0.3:
|
| 237 |
-
risk_score += 0.2
|
| 238 |
-
|
| 239 |
-
return min(1.0, risk_score)
|
| 240 |
-
|
| 241 |
-
# Test risk feature extraction
|
| 242 |
-
question = "What is machine learning?"
|
| 243 |
-
passages = [
|
| 244 |
-
{'text': 'ML is AI subset', 'score': 0.8},
|
| 245 |
-
{'text': 'Neural networks are used', 'score': 0.7},
|
| 246 |
-
{'text': 'Deep learning is popular', 'score': 0.6}
|
| 247 |
-
]
|
| 248 |
-
|
| 249 |
-
features = extract_risk_features(question, passages)
|
| 250 |
-
assert 'num_passages' in features, "Missing num_passages feature"
|
| 251 |
-
assert features['num_passages'] == 3, "Wrong number of passages: {}".format(features['num_passages'])
|
| 252 |
-
print("+ Risk feature extraction works")
|
| 253 |
-
|
| 254 |
-
# Test risk prediction
|
| 255 |
-
risk_score = predict_risk(features)
|
| 256 |
-
assert 0 <= risk_score <= 1, "Risk score out of range: {}".format(risk_score)
|
| 257 |
-
print("+ Risk prediction works")
|
| 258 |
-
|
| 259 |
-
return True
|
| 260 |
-
except Exception as e:
|
| 261 |
-
print("✗ Risk calibration test failed:", e)
|
| 262 |
-
return False
|
| 263 |
-
|
| 264 |
-
def test_generation():
|
| 265 |
-
"""Test generation functionality"""
|
| 266 |
-
print("\nTesting generation...")
|
| 267 |
-
|
| 268 |
-
try:
|
| 269 |
-
# Simple generation simulation
|
| 270 |
-
def generate_answer(question, retrieved_passages, risk_score):
|
| 271 |
-
# Simple template-based generation
|
| 272 |
-
context = " ".join([p['text'] for p in retrieved_passages[:3]])
|
| 273 |
-
|
| 274 |
-
if risk_score < 0.3:
|
| 275 |
-
# Low risk: confident answer
|
| 276 |
-
answer = "Based on the information: {}. The answer is: {}.".format(
|
| 277 |
-
context, "This is a confident answer."
|
| 278 |
-
)
|
| 279 |
-
elif risk_score < 0.7:
|
| 280 |
-
# Medium risk: cautious answer
|
| 281 |
-
answer = "Based on the available information: {}. The answer might be: {}.".format(
|
| 282 |
-
context, "This is a cautious answer."
|
| 283 |
-
)
|
| 284 |
-
else:
|
| 285 |
-
# High risk: uncertain answer
|
| 286 |
-
answer = "The available information: {} is limited. I'm not certain, but it might be: {}.".format(
|
| 287 |
-
context, "This is an uncertain answer."
|
| 288 |
-
)
|
| 289 |
-
|
| 290 |
-
return answer
|
| 291 |
-
|
| 292 |
-
# Test generation
|
| 293 |
-
question = "What is machine learning?"
|
| 294 |
-
passages = [
|
| 295 |
-
{'text': 'Machine learning is AI subset', 'score': 0.8},
|
| 296 |
-
{'text': 'It uses algorithms', 'score': 0.7}
|
| 297 |
-
]
|
| 298 |
-
|
| 299 |
-
# Test different risk levels
|
| 300 |
-
for risk_score in [0.2, 0.5, 0.8]:
|
| 301 |
-
answer = generate_answer(question, passages, risk_score)
|
| 302 |
-
assert len(answer) > 0, "Empty answer generated"
|
| 303 |
-
assert "machine learning" in answer.lower() or "ai" in answer.lower(), "Answer doesn't address question"
|
| 304 |
-
|
| 305 |
-
print("+ Generation works")
|
| 306 |
-
|
| 307 |
-
return True
|
| 308 |
-
except Exception as e:
|
| 309 |
-
print("✗ Generation test failed:", e)
|
| 310 |
-
return False
|
| 311 |
-
|
| 312 |
-
def test_evaluation():
|
| 313 |
-
"""Test evaluation functionality"""
|
| 314 |
-
print("\nTesting evaluation...")
|
| 315 |
-
|
| 316 |
-
try:
|
| 317 |
-
# Simple evaluation metrics
|
| 318 |
-
def exact_match(prediction, reference):
|
| 319 |
-
return prediction.lower().strip() == reference.lower().strip()
|
| 320 |
-
|
| 321 |
-
def f1_score(prediction, reference):
|
| 322 |
-
pred_words = set(prediction.lower().split())
|
| 323 |
-
ref_words = set(reference.lower().split())
|
| 324 |
-
|
| 325 |
-
if len(ref_words) == 0:
|
| 326 |
-
return 1.0 if len(pred_words) == 0 else 0.0
|
| 327 |
-
|
| 328 |
-
common = pred_words & ref_words
|
| 329 |
-
precision = len(common) / len(pred_words) if pred_words else 0.0
|
| 330 |
-
recall = len(common) / len(ref_words)
|
| 331 |
-
|
| 332 |
-
if precision + recall == 0:
|
| 333 |
-
return 0.0
|
| 334 |
-
|
| 335 |
-
return 2 * precision * recall / (precision + recall)
|
| 336 |
-
|
| 337 |
-
# Test evaluation
|
| 338 |
-
predictions = ["Machine learning is AI", "Deep learning uses neural networks"]
|
| 339 |
-
references = ["Machine learning is AI", "Deep learning uses neural networks"]
|
| 340 |
-
|
| 341 |
-
# Test exact match
|
| 342 |
-
em_scores = [exact_match(p, r) for p, r in zip(predictions, references)]
|
| 343 |
-
assert all(em_scores), "Exact match failed"
|
| 344 |
-
print("+ Exact match evaluation works")
|
| 345 |
-
|
| 346 |
-
# Test F1 score
|
| 347 |
-
f1_scores = [f1_score(p, r) for p, r in zip(predictions, references)]
|
| 348 |
-
assert all(0 <= score <= 1 for score in f1_scores), "F1 scores out of range"
|
| 349 |
-
print("+ F1 score evaluation works")
|
| 350 |
-
|
| 351 |
-
return True
|
| 352 |
-
except Exception as e:
|
| 353 |
-
print("✗ Evaluation test failed:", e)
|
| 354 |
-
return False
|
| 355 |
-
|
| 356 |
-
def test_end_to_end_workflow():
|
| 357 |
-
"""Test complete end-to-end workflow"""
|
| 358 |
-
print("\nTesting end-to-end workflow...")
|
| 359 |
-
|
| 360 |
-
try:
|
| 361 |
-
# Simulate complete RAG pipeline
|
| 362 |
-
def rag_pipeline(question):
|
| 363 |
-
# Step 1: Create simple embeddings
|
| 364 |
-
passages = [
|
| 365 |
-
"Machine learning is a subset of artificial intelligence.",
|
| 366 |
-
"Deep learning uses neural networks with multiple layers.",
|
| 367 |
-
"Natural language processing deals with text and speech.",
|
| 368 |
-
"Computer vision focuses on image and video analysis."
|
| 369 |
-
]
|
| 370 |
-
|
| 371 |
-
# Simulate embeddings
|
| 372 |
-
random.seed(42)
|
| 373 |
-
embeddings = []
|
| 374 |
-
for passage in passages:
|
| 375 |
-
embedding = [random.random() for _ in range(10)]
|
| 376 |
-
norm = math.sqrt(sum(x*x for x in embedding))
|
| 377 |
-
if norm > 0:
|
| 378 |
-
embedding = [x/norm for x in embedding]
|
| 379 |
-
embeddings.append(embedding)
|
| 380 |
-
|
| 381 |
-
# Step 2: Retrieve relevant passages
|
| 382 |
-
query_embedding = [random.random() for _ in range(10)]
|
| 383 |
-
norm = math.sqrt(sum(x*x for x in query_embedding))
|
| 384 |
-
if norm > 0:
|
| 385 |
-
query_embedding = [x/norm for x in query_embedding]
|
| 386 |
-
|
| 387 |
-
similarities = []
|
| 388 |
-
for embedding in embeddings:
|
| 389 |
-
sim = sum(x * y for x, y in zip(embedding, query_embedding))
|
| 390 |
-
similarities.append(sim)
|
| 391 |
-
|
| 392 |
-
indexed_sims = [(i, sim) for i, sim in enumerate(similarities)]
|
| 393 |
-
indexed_sims.sort(key=lambda x: x[1], reverse=True)
|
| 394 |
-
top_indices = [i for i, _ in indexed_sims[:3]]
|
| 395 |
-
|
| 396 |
-
retrieved_passages = []
|
| 397 |
-
for i, idx in enumerate(top_indices):
|
| 398 |
-
retrieved_passages.append({
|
| 399 |
-
'text': passages[idx],
|
| 400 |
-
'score': similarities[idx],
|
| 401 |
-
'rank': i + 1
|
| 402 |
-
})
|
| 403 |
-
|
| 404 |
-
# Step 3: Extract risk features
|
| 405 |
-
scores = [p['score'] for p in retrieved_passages]
|
| 406 |
-
features = {
|
| 407 |
-
'num_passages': len(retrieved_passages),
|
| 408 |
-
'avg_similarity': sum(scores) / len(scores) if scores else 0.0,
|
| 409 |
-
'diversity': 1.0 - math.sqrt(sum((x - sum(scores)/len(scores))**2 for x in scores) / len(scores)) if len(scores) > 1 else 1.0
|
| 410 |
-
}
|
| 411 |
-
|
| 412 |
-
# Step 4: Predict risk
|
| 413 |
-
risk_score = 0.0
|
| 414 |
-
if features['num_passages'] < 3:
|
| 415 |
-
risk_score += 0.3
|
| 416 |
-
if features['avg_similarity'] < 0.5:
|
| 417 |
-
risk_score += 0.2
|
| 418 |
-
if features['diversity'] < 0.3:
|
| 419 |
-
risk_score += 0.2
|
| 420 |
-
risk_score = min(1.0, risk_score)
|
| 421 |
-
|
| 422 |
-
# Step 5: Generate answer
|
| 423 |
-
context = " ".join([p['text'] for p in retrieved_passages[:3]])
|
| 424 |
-
if risk_score < 0.3:
|
| 425 |
-
answer = "Based on the information: {}. The answer is: Machine learning is a subset of AI.".format(context)
|
| 426 |
-
elif risk_score < 0.7:
|
| 427 |
-
answer = "Based on the available information: {}. The answer might be: Machine learning is likely a subset of AI.".format(context)
|
| 428 |
-
else:
|
| 429 |
-
answer = "The available information: {} is limited. I'm not certain, but it might be: Machine learning could be related to AI.".format(context)
|
| 430 |
-
|
| 431 |
-
return {
|
| 432 |
-
'question': question,
|
| 433 |
-
'answer': answer,
|
| 434 |
-
'retrieved_passages': retrieved_passages,
|
| 435 |
-
'risk_score': risk_score,
|
| 436 |
-
'features': features
|
| 437 |
-
}
|
| 438 |
-
|
| 439 |
-
# Test complete pipeline
|
| 440 |
-
question = "What is machine learning?"
|
| 441 |
-
result = rag_pipeline(question)
|
| 442 |
-
|
| 443 |
-
# Validate result
|
| 444 |
-
assert 'question' in result, "Missing question in result"
|
| 445 |
-
assert 'answer' in result, "Missing answer in result"
|
| 446 |
-
assert 'retrieved_passages' in result, "Missing retrieved passages"
|
| 447 |
-
assert 'risk_score' in result, "Missing risk score"
|
| 448 |
-
assert 'features' in result, "Missing features"
|
| 449 |
-
|
| 450 |
-
assert result['question'] == question, "Question not preserved"
|
| 451 |
-
assert len(result['answer']) > 0, "Empty answer"
|
| 452 |
-
assert len(result['retrieved_passages']) > 0, "No retrieved passages"
|
| 453 |
-
assert 0 <= result['risk_score'] <= 1, "Risk score out of range: {}".format(result['risk_score'])
|
| 454 |
-
|
| 455 |
-
print("+ End-to-end workflow works")
|
| 456 |
-
print(" Question: {}".format(result['question']))
|
| 457 |
-
print(" Answer: {}".format(result['answer'][:100] + "..."))
|
| 458 |
-
print(" Risk Score: {:.3f}".format(result['risk_score']))
|
| 459 |
-
print(" Retrieved Passages: {}".format(len(result['retrieved_passages'])))
|
| 460 |
-
|
| 461 |
-
return True
|
| 462 |
-
except Exception as e:
|
| 463 |
-
print("✗ End-to-end workflow test failed:", e)
|
| 464 |
-
return False
|
| 465 |
-
|
| 466 |
-
def main():
|
| 467 |
-
"""Run all end-to-end tests"""
|
| 468 |
-
print("SafeRAG Simple End-to-End Test Suite")
|
| 469 |
-
print("=" * 50)
|
| 470 |
-
|
| 471 |
-
start_time = time.time()
|
| 472 |
-
|
| 473 |
-
tests = [
|
| 474 |
-
test_basic_functionality,
|
| 475 |
-
test_text_processing,
|
| 476 |
-
test_simple_embeddings,
|
| 477 |
-
test_simple_retrieval,
|
| 478 |
-
test_risk_calibration,
|
| 479 |
-
test_generation,
|
| 480 |
-
test_evaluation,
|
| 481 |
-
test_end_to_end_workflow
|
| 482 |
-
]
|
| 483 |
-
|
| 484 |
-
passed = 0
|
| 485 |
-
total = len(tests)
|
| 486 |
-
|
| 487 |
-
for test in tests:
|
| 488 |
-
try:
|
| 489 |
-
if test():
|
| 490 |
-
passed += 1
|
| 491 |
-
except Exception as e:
|
| 492 |
-
print("✗ Test {} failed with exception: {}".format(test.__name__, e))
|
| 493 |
-
|
| 494 |
-
end_time = time.time()
|
| 495 |
-
|
| 496 |
-
print("\n" + "=" * 50)
|
| 497 |
-
print("Test Results:")
|
| 498 |
-
print("Passed: {}/{}".format(passed, total))
|
| 499 |
-
print("Time: {:.2f} seconds".format(end_time - start_time))
|
| 500 |
-
|
| 501 |
-
if passed == total:
|
| 502 |
-
print("✓ All tests passed! SafeRAG end-to-end workflow is working.")
|
| 503 |
-
print("\nThe system can:")
|
| 504 |
-
print("- Process text and extract sentences")
|
| 505 |
-
print("- Create simple embeddings and calculate similarities")
|
| 506 |
-
print("- Retrieve relevant passages based on similarity")
|
| 507 |
-
print("- Extract risk features and predict risk scores")
|
| 508 |
-
print("- Generate answers with different risk-aware strategies")
|
| 509 |
-
print("- Evaluate answers using standard metrics")
|
| 510 |
-
print("- Run complete end-to-end RAG pipeline")
|
| 511 |
-
return True
|
| 512 |
-
else:
|
| 513 |
-
print("✗ Some tests failed. Please check the errors above.")
|
| 514 |
-
return False
|
| 515 |
-
|
| 516 |
-
if __name__ == "__main__":
|
| 517 |
-
success = main()
|
| 518 |
-
sys.exit(0 if success else 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
simple_test.py
DELETED
|
@@ -1,167 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
# -*- coding: utf-8 -*-
|
| 3 |
-
"""
|
| 4 |
-
Simple SafeRAG Test
|
| 5 |
-
Basic functionality test without complex dependencies
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
import sys
|
| 9 |
-
import os
|
| 10 |
-
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 11 |
-
|
| 12 |
-
def test_imports():
|
| 13 |
-
"""Test that all modules can be imported"""
|
| 14 |
-
print("Testing imports...")
|
| 15 |
-
|
| 16 |
-
try:
|
| 17 |
-
from data_processing import DataLoader, Preprocessor
|
| 18 |
-
print("+ DataLoader and Preprocessor imported successfully")
|
| 19 |
-
except Exception as e:
|
| 20 |
-
print("✗ Failed to import DataLoader/Preprocessor:", e)
|
| 21 |
-
return False
|
| 22 |
-
|
| 23 |
-
try:
|
| 24 |
-
from retriever import Embedder, FAISSIndex, Retriever, Reranker
|
| 25 |
-
print("+ Retriever modules imported successfully")
|
| 26 |
-
except Exception as e:
|
| 27 |
-
print("✗ Failed to import retriever modules:", e)
|
| 28 |
-
return False
|
| 29 |
-
|
| 30 |
-
try:
|
| 31 |
-
from generator import VLLMServer, SafeGenerator, PromptTemplates
|
| 32 |
-
print("+ Generator modules imported successfully")
|
| 33 |
-
except Exception as e:
|
| 34 |
-
print("✗ Failed to import generator modules:", e)
|
| 35 |
-
return False
|
| 36 |
-
|
| 37 |
-
try:
|
| 38 |
-
from calibration import RiskFeatureExtractor, CalibrationHead
|
| 39 |
-
print("+ Calibration modules imported successfully")
|
| 40 |
-
except Exception as e:
|
| 41 |
-
print("✗ Failed to import calibration modules:", e)
|
| 42 |
-
return False
|
| 43 |
-
|
| 44 |
-
try:
|
| 45 |
-
from eval import QAEvaluator, AttributionEvaluator, CalibrationEvaluator
|
| 46 |
-
print("+ Evaluation modules imported successfully")
|
| 47 |
-
except Exception as e:
|
| 48 |
-
print("✗ Failed to import evaluation modules:", e)
|
| 49 |
-
return False
|
| 50 |
-
|
| 51 |
-
return True
|
| 52 |
-
|
| 53 |
-
def test_basic_functionality():
|
| 54 |
-
"""Test basic functionality without heavy dependencies"""
|
| 55 |
-
print("\nTesting basic functionality...")
|
| 56 |
-
|
| 57 |
-
try:
|
| 58 |
-
# Test Preprocessor
|
| 59 |
-
from data_processing.preprocessor import Preprocessor
|
| 60 |
-
preprocessor = Preprocessor()
|
| 61 |
-
|
| 62 |
-
# Test text cleaning
|
| 63 |
-
text = " This is a test text. "
|
| 64 |
-
cleaned = preprocessor.clean_text(text)
|
| 65 |
-
assert cleaned == "This is a test text.", "Expected 'This is a test text.', got '{}'".format(cleaned)
|
| 66 |
-
print("+ Text cleaning works")
|
| 67 |
-
|
| 68 |
-
# Test sentence extraction
|
| 69 |
-
text = "First sentence. Second sentence. Third sentence."
|
| 70 |
-
sentences = preprocessor.extract_sentences(text)
|
| 71 |
-
assert len(sentences) == 3, "Expected 3 sentences, got {}".format(len(sentences))
|
| 72 |
-
print("+ Sentence extraction works")
|
| 73 |
-
|
| 74 |
-
except Exception as e:
|
| 75 |
-
print("✗ Preprocessor test failed:", e)
|
| 76 |
-
return False
|
| 77 |
-
|
| 78 |
-
try:
|
| 79 |
-
# Test PromptTemplates
|
| 80 |
-
from generator.prompt_templates import PromptTemplates
|
| 81 |
-
templates = PromptTemplates()
|
| 82 |
-
|
| 83 |
-
# Test prompt formatting
|
| 84 |
-
prompt = templates.format_prompt(
|
| 85 |
-
'rag',
|
| 86 |
-
question="What is AI?",
|
| 87 |
-
context="AI is artificial intelligence."
|
| 88 |
-
)
|
| 89 |
-
assert "What is AI?" in prompt, "Question not found in prompt"
|
| 90 |
-
assert "AI is artificial intelligence." in prompt, "Context not found in prompt"
|
| 91 |
-
print("+ Prompt templates work")
|
| 92 |
-
|
| 93 |
-
except Exception as e:
|
| 94 |
-
print("✗ PromptTemplates test failed:", e)
|
| 95 |
-
return False
|
| 96 |
-
|
| 97 |
-
try:
|
| 98 |
-
# Test QAEvaluator
|
| 99 |
-
from eval.eval_qa import QAEvaluator
|
| 100 |
-
evaluator = QAEvaluator()
|
| 101 |
-
|
| 102 |
-
# Test exact match
|
| 103 |
-
predictions = ["Paris", "Paris"]
|
| 104 |
-
references = ["Paris", "London"]
|
| 105 |
-
em = evaluator.exact_match(predictions, references)
|
| 106 |
-
assert em == 0.5, "Expected 0.5, got {}".format(em)
|
| 107 |
-
print("+ QA evaluation works")
|
| 108 |
-
|
| 109 |
-
except Exception as e:
|
| 110 |
-
print("✗ QAEvaluator test failed:", e)
|
| 111 |
-
return False
|
| 112 |
-
|
| 113 |
-
return True
|
| 114 |
-
|
| 115 |
-
def test_config():
|
| 116 |
-
"""Test configuration loading"""
|
| 117 |
-
print("\nTesting configuration...")
|
| 118 |
-
|
| 119 |
-
try:
|
| 120 |
-
import yaml
|
| 121 |
-
with open('config.yaml', 'r') as f:
|
| 122 |
-
config = yaml.safe_load(f)
|
| 123 |
-
|
| 124 |
-
# Check required sections
|
| 125 |
-
required_sections = ['models', 'data', 'index', 'retrieval', 'calibration', 'evaluation']
|
| 126 |
-
for section in required_sections:
|
| 127 |
-
assert section in config, "Missing config section: {}".format(section)
|
| 128 |
-
|
| 129 |
-
print("+ Configuration file is valid")
|
| 130 |
-
return True
|
| 131 |
-
|
| 132 |
-
except Exception as e:
|
| 133 |
-
print("✗ Configuration test failed:", e)
|
| 134 |
-
return False
|
| 135 |
-
|
| 136 |
-
def main():
|
| 137 |
-
"""Run all tests"""
|
| 138 |
-
print("SafeRAG Simple Test Suite")
|
| 139 |
-
print("=" * 40)
|
| 140 |
-
|
| 141 |
-
all_passed = True
|
| 142 |
-
|
| 143 |
-
# Test imports
|
| 144 |
-
if not test_imports():
|
| 145 |
-
all_passed = False
|
| 146 |
-
|
| 147 |
-
# Test basic functionality
|
| 148 |
-
if not test_basic_functionality():
|
| 149 |
-
all_passed = False
|
| 150 |
-
|
| 151 |
-
# Test configuration
|
| 152 |
-
if not test_config():
|
| 153 |
-
all_passed = False
|
| 154 |
-
|
| 155 |
-
print("\n" + "=" * 40)
|
| 156 |
-
if all_passed:
|
| 157 |
-
print("+ All tests passed!")
|
| 158 |
-
print("SafeRAG is ready to use.")
|
| 159 |
-
else:
|
| 160 |
-
print("✗ Some tests failed.")
|
| 161 |
-
print("Please check the errors above.")
|
| 162 |
-
|
| 163 |
-
return all_passed
|
| 164 |
-
|
| 165 |
-
if __name__ == "__main__":
|
| 166 |
-
success = main()
|
| 167 |
-
sys.exit(0 if success else 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|