.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
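Each pattern added above routes matching files through Git LFS (`filter=lfs diff=lfs merge=lfs`) and marks them binary (`-text`). As a rough illustration of which paths the rules capture, here is a small sketch using Python's `fnmatch` as a stand-in for Git's wildmatch (an approximation: `fnmatch` does not treat `**` specially, so patterns like `saved_model/**/*` are omitted):

```python
from fnmatch import fnmatch

# A few of the patterns from the .gitattributes above; fnmatch is only an
# approximation of Git's pattern matching, used here for illustration.
LFS_PATTERNS = ["*.7z", "*.safetensors", "*.parquet", "*tfevents*"]

def tracked_by_lfs(path: str) -> bool:
    return any(fnmatch(path, pattern) for pattern in LFS_PATTERNS)

print(tracked_by_lfs("model.safetensors"))             # True
print(tracked_by_lfs("runs/events.out.tfevents.123"))  # True
print(tracked_by_lfs("app.py"))                        # False
```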
.gitignore DELETED
@@ -1,49 +0,0 @@
- # Python
- __pycache__/
- *.py[cod]
- *$py.class
- *.so
- .Python
- build/
- develop-eggs/
- dist/
- downloads/
- eggs/
- .eggs/
- lib/
- lib64/
- parts/
- sdist/
- var/
- wheels/
- *.egg-info/
- .installed.cfg
- *.egg
-
- # Virtual environments
- venv/
- env/
- ENV/
-
- # IDE
- .vscode/
- .idea/
- *.swp
- *.swo
-
- # OS
- .DS_Store
- Thumbs.db
-
- # Project specific
- cache/
- logs/
- results/
- models/
- index/
- data/
- *.log
-
- # Temporary files
- *.tmp
- *.temp
PROJECT_INFO.md DELETED
@@ -1,141 +0,0 @@
- # SafeRAG Project Information
-
- ## 📁 Project Structure
-
- ```
- safe_rag/
- ├── app.py                    # Gradio demo app
- ├── requirements.txt          # Python dependencies
- ├── config.yaml               # Configuration file
- ├── README.md                 # Project overview (HF Spaces config)
- ├── simple_e2e_test.py        # End-to-end tests
- ├── simple_test.py            # Basic functionality tests
- ├── data_processing/          # Data processing module
- │   ├── __init__.py
- │   ├── data_loader.py        # Data loaders
- │   └── preprocessor.py       # Text preprocessor
- ├── retriever/                # Retrieval module
- │   ├── __init__.py
- │   ├── embedder.py           # Embedding generator
- │   ├── faiss_index.py        # FAISS index
- │   ├── retriever.py          # Retriever
- │   └── reranker.py           # Reranker
- ├── generator/                # Generation module
- │   ├── __init__.py
- │   ├── vllm_server.py        # vLLM server
- │   ├── prompt_templates.py   # Prompt templates
- │   └── safe_generate.py      # Safe generator
- ├── calibration/              # Calibration module
- │   ├── __init__.py
- │   ├── features.py           # Feature extraction
- │   ├── calibration_head.py   # Calibration head
- │   └── trainer.py            # Trainer
- └── eval/                     # Evaluation module
-     ├── __init__.py
-     ├── eval_qa.py            # QA evaluation
-     ├── eval_attr.py          # Attribution evaluation
-     ├── eval_calib.py         # Calibration evaluation
-     └── eval_system.py        # System evaluation
- ```
-
- ## 🚀 Core Features
-
- ### 1. Data Processing (`data_processing/`)
- - **DataLoader**: Loads HF Datasets (HotpotQA, TriviaQA, Wikipedia)
- - **Preprocessor**: Text cleaning, sentence splitting, tokenization
-
- ### 2. Retrieval System (`retriever/`)
- - **Embedder**: Generates embeddings with BGE/E5
- - **FAISSIndex**: Builds and searches FAISS indexes
- - **Retriever**: Batched retrieval of relevant documents
- - **Reranker**: Reranking to improve retrieval quality
-
- ### 3. Generation System (`generator/`)
- - **VLLMServer**: vLLM inference server
- - **SafeGenerator**: Risk-aware answer generation
- - **PromptTemplates**: Prompt template management
-
- ### 4. Risk Calibration (`calibration/`)
- - **RiskFeatureExtractor**: Extracts 16-dimensional risk features
- - **CalibrationHead**: LogReg/MLP calibration head
- - **Trainer**: Calibration-head training
-
- ### 5. Evaluation System (`eval/`)
- - **QAEvaluator**: EM/F1 evaluation
- - **AttributionEvaluator**: Citation attribution evaluation
- - **CalibrationEvaluator**: Calibration quality evaluation
- - **SystemEvaluator**: System performance evaluation
-
- ## 🎯 Risk Calibration Strategy
-
- ### Risk Features (16-dimensional)
- 1. **Retrieval statistics**: similarity scores, variance, diversity
- 2. **Coverage features**: token/entity overlap between question and answer
- 3. **Consistency features**: semantic similarity between passages
- 4. **Diversity features**: topic variance, passage diversity
-
- ### Adaptive Strategy
- - **Low risk (r < 0.3)**: normal generation
- - **Medium risk (0.3 ≤ r < 0.7)**: conservative generation + mandatory citations
- - **High risk (r ≥ 0.7)**: very conservative generation, or refuse to answer
-
- ## 📊 Performance Targets
-
- - **QA accuracy**: EM/F1 gains over vanilla RAG
- - **Attribution quality**: +8-12pt citation precision/recall
- - **Calibration quality**: 30-40% ECE reduction
- - **System throughput**: 2-3.5x speedup from vLLM
-
- ## 🧪 Testing and Validation
-
- ### End-to-End Tests (`simple_e2e_test.py`)
- - ✅ 8/8 tests passing
- - ✅ Full RAG pipeline verified
- - ✅ All core features working
-
- ### Basic Tests (`simple_test.py`)
- - ✅ Module import tests
- - ✅ Basic functionality checks
- - ✅ Configuration checks
-
- ## 🚀 Deploying to Hugging Face Spaces
-
- ### 1. Upload Files
- - Upload the entire `safe_rag` directory to HF Spaces
- - Make sure `app.py` is in the root directory
-
- ### 2. Configure the Space
- - SDK: Gradio
- - Hardware: GPU (A10G or A100 recommended)
- - Environment: Python 3.8+
-
- ### 3. Automatic Deployment
- - HF Spaces installs dependencies automatically
- - `app.py` is started automatically
- - A public access link is provided
-
- ## 📝 Usage
-
- ### Run Locally
- ```bash
- # Install dependencies
- pip install -r requirements.txt
-
- # Run tests
- python3 simple_e2e_test.py
-
- # Start the demo
- python3 app.py
- ```
-
- ### Online Demo
- Visit the Hugging Face Spaces link to try the interactive RAG system.
-
- ## 🎉 Project Status
-
- ✅ **Done**: all core modules implemented
- ✅ **Tested**: end-to-end tests passing
- ✅ **Simplified**: unnecessary files removed
- ✅ **Ready**: deployable to HF Spaces
-
- SafeRAG is ready to deploy and use!
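The adaptive strategy documented in the deleted file maps a calibrated risk score r to a generation mode. A minimal sketch of that mapping (function and mode names are hypothetical; the 0.3/0.7 thresholds are the tau1/tau2 values from the project files):

```python
# Hypothetical sketch of the adaptive strategy described in PROJECT_INFO.md:
# map a calibrated risk score r in [0, 1] to a generation mode.
TAU1, TAU2 = 0.3, 0.7  # thresholds documented in the deleted config

def select_strategy(risk_score: float) -> str:
    if risk_score < TAU1:
        return "normal"        # low risk: answer directly
    if risk_score < TAU2:
        return "conservative"  # medium risk: hedge and force citations
    return "refuse"            # high risk: very conservative or abstain

print(select_strategy(0.1))  # normal
```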
README.md CHANGED
@@ -1,108 +1,16 @@
  ---
- title: SafeRAG Demo
- emoji: 🤖
- colorFrom: blue
  colorTo: purple
  sdk: gradio
- sdk_version: 4.0.0
  app_file: app.py
  pinned: false
- license: apache-2.0
  ---
 
- # SafeRAG: High-Performance Calibrated RAG
-
- A production-ready Retrieval-Augmented Generation (RAG) system with risk calibration, built on the Hugging Face ecosystem.
-
- ## 🚀 Key Features
-
- - **Risk Calibration**: Multi-layer risk assessment with adaptive strategies
- - **High Performance**: Optimized for 2-3.5x throughput improvement
- - **Hugging Face Native**: Built on HF Datasets, Models, and Spaces
- - **Production Ready**: Complete pipeline with error handling and monitoring
-
- ## 🏗️ Architecture
-
- ```
- HF Datasets → Embedding (BGE/E5) → FAISS Index
- Query → Batched Retrieval → Evidence Selector → Generator (vLLM + gpt-oss-20b)
- → Risk Calibration → Adaptive Strategy → Output (Answer + Citations + Risk Score)
- ```
-
- ## 📊 Performance Targets
-
- - **QA Accuracy**: EM/F1 improvements over vanilla RAG
- - **Attribution**: +8-12pt improvement in citation precision/recall
- - **Calibration**: 30-40% reduction in ECE (Expected Calibration Error)
- - **Throughput**: 2-3.5x improvement with vLLM
-
- ## 🛠️ Quick Start
-
- ### Run Tests
- ```bash
- python3 simple_e2e_test.py
- ```
-
- ### Start Demo
- ```bash
- python3 app.py
- ```
-
- ## 📈 Evaluation
-
- The system has been tested with comprehensive end-to-end tests:
-
- - ✅ Text processing and sentence extraction
- - ✅ Embedding creation and similarity calculation
- - ✅ Passage retrieval and reranking
- - ✅ Risk feature extraction and prediction
- - ✅ Risk-aware answer generation
- - ✅ Evaluation metrics (EM, F1, ROUGE)
- - ✅ Complete end-to-end RAG pipeline
-
- ## 🔧 Configuration
-
- Key parameters in `config.yaml`:
-
- - **Risk Thresholds**: τ₁ = 0.3, τ₂ = 0.7
- - **Retrieval**: k = 20, rerank_k = 10
- - **Generation**: max_tokens = 512, temperature = 0.7
- - **Calibration**: 16 features, logistic regression
-
- ## 🎯 Risk Calibration
-
- ### Risk Features (16-dimensional)
- 1. **Retrieval Statistics**: Similarity scores, variance, diversity
- 2. **Coverage Features**: Token/entity overlap between Q&A
- 3. **Consistency Features**: Semantic similarity between passages
- 4. **Diversity Features**: Topic variance, passage diversity
-
- ### Adaptive Strategies
- - **Low Risk (r < τ₁)**: Normal generation
- - **Medium Risk (τ₁ ≤ r < τ₂)**: Conservative generation + citations
- - **High Risk (r ≥ τ₂)**: Very conservative or refuse
-
- ## 📚 Datasets
-
- - **HotpotQA**: Multi-hop reasoning with supporting facts
- - **TriviaQA**: Open-domain QA for general knowledge
- - **Wikipedia**: Knowledge base via HF Datasets
-
- ## 📄 Citation
-
- ```bibtex
- @article{safrag2024,
-   title={SafeRAG: High-Performance Calibrated RAG with Risk Assessment},
-   author={Your Name},
-   journal={arXiv preprint},
-   year={2024}
- }
- ```
-
- ## 📝 License
-
- Apache 2.0 License - see LICENSE file for details.
-
- ---
-
- **SafeRAG**: A production-ready RAG system with risk calibration, built on Hugging Face ecosystem.

  ---
+ title: Safe Rag
+ emoji: 💬
+ colorFrom: yellow
  colorTo: purple
  sdk: gradio
+ sdk_version: 5.42.0
  app_file: app.py
  pinned: false
+ hf_oauth: true
+ hf_oauth_scopes:
+   - inference-api
+ short_description: A High-Performance and Risk-Calibrated RAG system
  ---
 
+ An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
calibration/__init__.py DELETED
@@ -1,5 +0,0 @@
- from .features import RiskFeatureExtractor
- from .calibration_head import CalibrationHead
- from .trainer import CalibrationTrainer
-
- __all__ = ['RiskFeatureExtractor', 'CalibrationHead', 'CalibrationTrainer']
calibration/calibration_head.py DELETED
@@ -1,210 +0,0 @@
- import torch
- import torch.nn as nn
- import numpy as np
- from sklearn.linear_model import LogisticRegression
- from sklearn.ensemble import RandomForestClassifier
- from sklearn.metrics import accuracy_score, roc_auc_score
- from typing import Dict, Any, List, Tuple
- import logging
- import joblib
- import os
-
- logger = logging.getLogger(__name__)
-
- class CalibrationHead:
-     def __init__(self, model_type: str = "logistic", input_dim: int = 16):
-         self.model_type = model_type
-         self.input_dim = input_dim
-         self.model = None
-         self.is_trained = False
-
-     def _create_model(self):
-         """Create the calibration model"""
-         if self.model_type == "logistic":
-             self.model = LogisticRegression(
-                 random_state=42,
-                 max_iter=1000,
-                 class_weight='balanced'
-             )
-         elif self.model_type == "random_forest":
-             self.model = RandomForestClassifier(
-                 n_estimators=100,
-                 random_state=42,
-                 class_weight='balanced'
-             )
-         elif self.model_type == "mlp":
-             self.model = MLPCalibrationHead(self.input_dim)
-         else:
-             raise ValueError(f"Unknown model type: {self.model_type}")
-
-     def train(self, X: np.ndarray, y: np.ndarray) -> Dict[str, float]:
-         """Train the calibration model"""
-         if self.model is None:
-             self._create_model()
-
-         if self.model_type in ["logistic", "random_forest"]:
-             # Sklearn models
-             self.model.fit(X, y)
-
-             # Get predictions and metrics
-             y_pred = self.model.predict(X)
-             y_proba = self.model.predict_proba(X)[:, 1] if hasattr(self.model, 'predict_proba') else y_pred
-
-             metrics = {
-                 'accuracy': accuracy_score(y, y_pred),
-                 'auc': roc_auc_score(y, y_proba) if len(np.unique(y)) > 1 else 0.0
-             }
-         else:
-             # PyTorch models
-             metrics = self._train_pytorch_model(X, y)
-
-         self.is_trained = True
-         logger.info(f"Trained {self.model_type} model with metrics: {metrics}")
-         return metrics
-
-     def predict_risk(self, features: Dict[str, Any]) -> float:
-         """Predict risk score from features"""
-         if not self.is_trained:
-             logger.warning("Model not trained, returning default risk score")
-             return 0.5
-
-         # Convert features to array
-         X = self._features_to_array(features)
-
-         if self.model_type in ["logistic", "random_forest"]:
-             if hasattr(self.model, 'predict_proba'):
-                 risk_score = self.model.predict_proba(X.reshape(1, -1))[0, 1]
-             else:
-                 risk_score = float(self.model.predict(X.reshape(1, -1))[0])
-         else:
-             # PyTorch models
-             with torch.no_grad():
-                 X_tensor = torch.FloatTensor(X.reshape(1, -1))
-                 risk_score = torch.sigmoid(self.model(X_tensor)).item()
-
-         return float(risk_score)
-
-     def predict_batch(self, features_list: List[Dict[str, Any]]) -> List[float]:
-         """Predict risk scores for multiple feature sets"""
-         if not features_list:
-             return []
-
-         # Convert all features to arrays
-         X = np.array([self._features_to_array(f) for f in features_list])
-
-         if self.model_type in ["logistic", "random_forest"]:
-             if hasattr(self.model, 'predict_proba'):
-                 risk_scores = self.model.predict_proba(X)[:, 1]
-             else:
-                 risk_scores = self.model.predict(X)
-         else:
-             # PyTorch models
-             with torch.no_grad():
-                 X_tensor = torch.FloatTensor(X)
-                 risk_scores = torch.sigmoid(self.model(X_tensor)).numpy()
-
-         return risk_scores.tolist()
-
-     def _features_to_array(self, features: Dict[str, Any]) -> np.ndarray:
-         """Convert features dictionary to numpy array"""
-         # Define feature order (must match training)
-         feature_order = [
-             'num_passages', 'avg_similarity', 'std_similarity', 'max_similarity',
-             'min_similarity', 'score_variance', 'avg_token_overlap', 'max_token_overlap',
-             'avg_entity_overlap', 'max_entity_overlap', 'passage_consistency',
-             'passage_consistency_std', 'min_passage_similarity', 'diversity',
-             'topic_variance'
-         ]
-
-         # Extract features in order
-         feature_array = []
-         for feature_name in feature_order:
-             value = features.get(feature_name, 0.0)
-             feature_array.append(float(value))
-
-         return np.array(feature_array)
-
-     def _train_pytorch_model(self, X: np.ndarray, y: np.ndarray) -> Dict[str, float]:
-         """Train PyTorch model"""
-         # Convert to tensors
-         X_tensor = torch.FloatTensor(X)
-         y_tensor = torch.FloatTensor(y)
-
-         # Training setup
-         optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
-         criterion = nn.BCEWithLogitsLoss()
-
-         # Training loop
-         self.model.train()
-         for epoch in range(100):
-             optimizer.zero_grad()
-             outputs = self.model(X_tensor)
-             loss = criterion(outputs.squeeze(), y_tensor)
-             loss.backward()
-             optimizer.step()
-
-         # Evaluation
-         self.model.eval()
-         with torch.no_grad():
-             outputs = self.model(X_tensor)
-             predictions = torch.sigmoid(outputs).squeeze().numpy()
-             binary_preds = (predictions > 0.5).astype(int)
-
-         metrics = {
-             'accuracy': accuracy_score(y, binary_preds),
-             'auc': roc_auc_score(y, predictions) if len(np.unique(y)) > 1 else 0.0
-         }
-
-         return metrics
-
-     def save(self, path: str) -> None:
-         """Save the trained model"""
-         os.makedirs(os.path.dirname(path), exist_ok=True)
-
-         if self.model_type in ["logistic", "random_forest"]:
-             joblib.dump(self.model, f"{path}.joblib")
-         else:
-             torch.save(self.model.state_dict(), f"{path}.pth")
-
-         # Save metadata
-         metadata = {
-             'model_type': self.model_type,
-             'input_dim': self.input_dim,
-             'is_trained': self.is_trained
-         }
-         joblib.dump(metadata, f"{path}_metadata.joblib")
-
-         logger.info(f"Saved model to {path}")
-
-     def load(self, path: str) -> None:
-         """Load a trained model"""
-         # Load metadata
-         metadata = joblib.load(f"{path}_metadata.joblib")
-         self.model_type = metadata['model_type']
-         self.input_dim = metadata['input_dim']
-         self.is_trained = metadata['is_trained']
-
-         # Load model
-         if self.model_type in ["logistic", "random_forest"]:
-             self.model = joblib.load(f"{path}.joblib")
-         else:
-             self.model = MLPCalibrationHead(self.input_dim)
-             self.model.load_state_dict(torch.load(f"{path}.pth"))
-
-         logger.info(f"Loaded model from {path}")
-
- class MLPCalibrationHead(nn.Module):
-     def __init__(self, input_dim: int, hidden_dim: int = 64):
-         super().__init__()
-         self.layers = nn.Sequential(
-             nn.Linear(input_dim, hidden_dim),
-             nn.ReLU(),
-             nn.Dropout(0.2),
-             nn.Linear(hidden_dim, hidden_dim // 2),
-             nn.ReLU(),
-             nn.Dropout(0.2),
-             nn.Linear(hidden_dim // 2, 1)
-         )
-
-     def forward(self, x):
-         return self.layers(x)
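One detail worth flagging in the deleted head: `CalibrationHead` defaults to `input_dim=16` (and the project docs speak of 16-dimensional risk features), yet `_features_to_array` enumerates only 15 feature names. A quick count reproduces the mismatch:

```python
# Feature names copied verbatim from the deleted _features_to_array.
FEATURE_ORDER = [
    'num_passages', 'avg_similarity', 'std_similarity', 'max_similarity',
    'min_similarity', 'score_variance', 'avg_token_overlap', 'max_token_overlap',
    'avg_entity_overlap', 'max_entity_overlap', 'passage_consistency',
    'passage_consistency_std', 'min_passage_similarity', 'diversity',
    'topic_variance',
]
print(len(FEATURE_ORDER))  # 15, one short of the declared input_dim of 16
```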
 
calibration/features.py DELETED
@@ -1,173 +0,0 @@
- from typing import List, Dict, Any
- import numpy as np
- from sentence_transformers import SentenceTransformer
- import logging
- from sklearn.feature_extraction.text import TfidfVectorizer
- from sklearn.metrics.pairwise import cosine_similarity
- import re
-
- logger = logging.getLogger(__name__)
-
- class RiskFeatureExtractor:
-     def __init__(self, embedding_model: str = "BAAI/bge-large-en-v1.5"):
-         self.embedding_model = SentenceTransformer(embedding_model)
-         self.tfidf_vectorizer = TfidfVectorizer(max_features=1000, stop_words='english')
-
-     def extract_features(self, question: str, retrieved_passages: List[Dict[str, Any]]) -> Dict[str, Any]:
-         """Extract risk assessment features"""
-         if not retrieved_passages:
-             return self._get_empty_features()
-
-         features = {}
-
-         # Retrieval statistics
-         features.update(self._extract_retrieval_stats(retrieved_passages))
-
-         # Coverage features
-         features.update(self._extract_coverage_features(question, retrieved_passages))
-
-         # Consistency features
-         features.update(self._extract_consistency_features(question, retrieved_passages))
-
-         # Diversity features
-         features.update(self._extract_diversity_features(retrieved_passages))
-
-         return features
-
-     def _extract_retrieval_stats(self, passages: List[Dict[str, Any]]) -> Dict[str, Any]:
-         """Extract retrieval statistics"""
-         if not passages:
-             return {}
-
-         scores = [p.get('score', 0.0) for p in passages]
-
-         return {
-             'num_passages': len(passages),
-             'avg_similarity': np.mean(scores),
-             'std_similarity': np.std(scores),
-             'max_similarity': np.max(scores),
-             'min_similarity': np.min(scores),
-             'score_variance': np.var(scores)
-         }
-
-     def _extract_coverage_features(self, question: str, passages: List[Dict[str, Any]]) -> Dict[str, Any]:
-         """Extract coverage features between question and passages"""
-         if not passages:
-             return {}
-
-         # Token overlap
-         question_tokens = set(question.lower().split())
-         passage_texts = [p.get('text', '') for p in passages]
-
-         overlaps = []
-         for passage_text in passage_texts:
-             passage_tokens = set(passage_text.lower().split())
-             overlap = len(question_tokens.intersection(passage_tokens))
-             overlaps.append(overlap / len(question_tokens) if question_tokens else 0)
-
-         # Entity overlap (simplified)
-         question_entities = self._extract_entities(question)
-         entity_overlaps = []
-
-         for passage_text in passage_texts:
-             passage_entities = self._extract_entities(passage_text)
-             overlap = len(question_entities.intersection(passage_entities))
-             entity_overlaps.append(overlap / len(question_entities) if question_entities else 0)
-
-         return {
-             'avg_token_overlap': np.mean(overlaps),
-             'max_token_overlap': np.max(overlaps),
-             'avg_entity_overlap': np.mean(entity_overlaps),
-             'max_entity_overlap': np.max(entity_overlaps)
-         }
-
-     def _extract_consistency_features(self, question: str, passages: List[Dict[str, Any]]) -> Dict[str, Any]:
-         """Extract consistency features between passages"""
-         if len(passages) < 2:
-             return {'passage_consistency': 1.0}
-
-         # Semantic similarity between passages
-         passage_texts = [p.get('text', '') for p in passages]
-         embeddings = self.embedding_model.encode(passage_texts)
-
-         # Compute pairwise similarities
-         similarities = cosine_similarity(embeddings)
-
-         # Get upper triangle (excluding diagonal)
-         upper_triangle = similarities[np.triu_indices_from(similarities, k=1)]
-
-         return {
-             'passage_consistency': np.mean(upper_triangle),
-             'passage_consistency_std': np.std(upper_triangle),
-             'min_passage_similarity': np.min(upper_triangle)
-         }
-
-     def _extract_diversity_features(self, passages: List[Dict[str, Any]]) -> Dict[str, Any]:
-         """Extract diversity features"""
-         if len(passages) < 2:
-             return {'diversity': 1.0}
-
-         # Topic diversity using TF-IDF
-         passage_texts = [p.get('text', '') for p in passages]
-
-         try:
-             tfidf_matrix = self.tfidf_vectorizer.fit_transform(passage_texts)
-             similarities = cosine_similarity(tfidf_matrix)
-
-             # Diversity as 1 - average similarity
-             upper_triangle = similarities[np.triu_indices_from(similarities, k=1)]
-             diversity = 1.0 - np.mean(upper_triangle)
-
-             return {
-                 'diversity': diversity,
-                 'topic_variance': np.var(upper_triangle)
-             }
-         except:
-             return {'diversity': 0.5, 'topic_variance': 0.0}
-
-     def _extract_entities(self, text: str) -> set:
-         """Extract entities from text (simplified)"""
-         # Simple entity extraction - in practice use NER
-         # Look for capitalized words and common entity patterns
-         entities = set()
-
-         # Capitalized words (potential entities)
-         capitalized = re.findall(r'\b[A-Z][a-z]+\b', text)
-         entities.update(capitalized)
-
-         # Numbers and dates
-         numbers = re.findall(r'\b\d+\b', text)
-         entities.update(numbers)
-
-         return entities
-
-     def _get_empty_features(self) -> Dict[str, Any]:
-         """Return empty features when no passages available"""
-         return {
-             'num_passages': 0,
-             'avg_similarity': 0.0,
-             'std_similarity': 0.0,
-             'max_similarity': 0.0,
-             'min_similarity': 0.0,
-             'score_variance': 0.0,
-             'avg_token_overlap': 0.0,
-             'max_token_overlap': 0.0,
-             'avg_entity_overlap': 0.0,
-             'max_entity_overlap': 0.0,
-             'passage_consistency': 0.0,
-             'passage_consistency_std': 0.0,
-             'min_passage_similarity': 0.0,
-             'diversity': 0.0,
-             'topic_variance': 0.0
-         }
-
-     def extract_batch_features(self, questions: List[str],
-                                passages_list: List[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
-         """Extract features for multiple question-passage pairs"""
-         features_list = []
-
-         for question, passages in zip(questions, passages_list):
-             features = self.extract_features(question, passages)
-             features_list.append(features)
-
-         return features_list
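The coverage features in the deleted extractor reduce to a token-overlap fraction between the question and each passage. A self-contained restatement in pure Python (mirroring `_extract_coverage_features`; the helper name is ours):

```python
def token_overlap(question: str, passage: str) -> float:
    # Fraction of the question's tokens that also occur in the passage,
    # as computed by the deleted _extract_coverage_features.
    q_tokens = set(question.lower().split())
    p_tokens = set(passage.lower().split())
    return len(q_tokens & p_tokens) / len(q_tokens) if q_tokens else 0.0

print(round(token_overlap("Who wrote Hamlet", "Hamlet was written by Shakespeare"), 3))  # 0.333
```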
 
calibration/trainer.py DELETED
@@ -1,171 +0,0 @@
- from typing import List, Dict, Any, Tuple
- import numpy as np
- from sklearn.model_selection import train_test_split
- from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
- import logging
- from .features import RiskFeatureExtractor
- from .calibration_head import CalibrationHead
-
- logger = logging.getLogger(__name__)
-
- class CalibrationTrainer:
-     def __init__(self, feature_extractor: RiskFeatureExtractor,
-                  calibration_head: CalibrationHead):
-         self.feature_extractor = feature_extractor
-         self.calibration_head = calibration_head
-
-     def prepare_training_data(self, qa_data: List[Dict[str, Any]],
-                               retrieved_passages_list: List[List[Dict[str, Any]]],
-                               labels: List[int]) -> Tuple[np.ndarray, np.ndarray]:
-         """Prepare training data from QA samples and retrieved passages"""
-
-         # Extract features
-         features_list = self.feature_extractor.extract_batch_features(
-             [item['question'] for item in qa_data],
-             retrieved_passages_list
-         )
-
-         # Convert features to arrays
-         X = np.array([self.feature_extractor._features_to_array(f) for f in features_list])
-         y = np.array(labels)
-
-         logger.info(f"Prepared training data: {X.shape[0]} samples, {X.shape[1]} features")
-         return X, y
-
-     def train(self, X: np.ndarray, y: np.ndarray,
-               test_size: float = 0.2, random_state: int = 42) -> Dict[str, Any]:
-         """Train the calibration model"""
-
-         # Split data
-         X_train, X_test, y_train, y_test = train_test_split(
-             X, y, test_size=test_size, random_state=random_state, stratify=y
-         )
-
-         # Train model
-         train_metrics = self.calibration_head.train(X_train, y_train)
-
-         # Evaluate on test set
-         test_metrics = self.evaluate(X_test, y_test)
-
-         # Combine metrics
-         all_metrics = {
-             'train': train_metrics,
-             'test': test_metrics,
-             'train_size': len(X_train),
-             'test_size': len(X_test)
-         }
-
-         logger.info(f"Training completed. Test metrics: {test_metrics}")
-         return all_metrics
-
-     def evaluate(self, X: np.ndarray, y: np.ndarray) -> Dict[str, float]:
-         """Evaluate the calibration model"""
-         if not self.calibration_head.is_trained:
-             raise ValueError("Model not trained yet")
-
-         # Get predictions
-         if hasattr(self.calibration_head.model, 'predict_proba'):
-             y_proba = self.calibration_head.model.predict_proba(X)[:, 1]
-             y_pred = (y_proba > 0.5).astype(int)
-         else:
-             y_pred = self.calibration_head.model.predict(X)
-             y_proba = y_pred
-
-         # Calculate metrics
-         accuracy = accuracy_score(y, y_pred)
-         precision, recall, f1, _ = precision_recall_fscore_support(y, y_pred, average='binary')
-
-         try:
-             auc = roc_auc_score(y, y_proba)
-         except:
-             auc = 0.0
-
-         return {
-             'accuracy': accuracy,
-             'precision': precision,
-             'recall': recall,
-             'f1': f1,
-             'auc': auc
-         }
-
-     def create_synthetic_labels(self, qa_data: List[Dict[str, Any]],
-                                 retrieved_passages_list: List[List[Dict[str, Any]]]) -> List[int]:
-         """Create synthetic risk labels for training (placeholder implementation)"""
-         labels = []
-
-         for qa_item, passages in zip(qa_data, retrieved_passages_list):
-             # Simple heuristic for risk labeling
-             # In practice, this would be based on human annotations or automated evaluation
-
-             question = qa_item['question']
-             answer = qa_item['answer']
-
-             # Risk factors
-             risk_score = 0.0
-
-             # Low similarity scores = high risk
-             if passages:
-                 avg_similarity = np.mean([p.get('score', 0.0) for p in passages])
-                 if avg_similarity < 0.3:
-                     risk_score += 0.3
-
-             # Few passages = high risk
-             if len(passages) < 3:
-                 risk_score += 0.2
-
-             # Question complexity (length, question words)
-             if len(question.split()) > 20:
-                 risk_score += 0.1
-
-             if any(word in question.lower() for word in ['why', 'how', 'explain', 'compare']):
-                 risk_score += 0.1
-
-             # Answer length (very short or very long answers might be risky)
-             if len(answer.split()) < 5 or len(answer.split()) > 100:
-                 risk_score += 0.1
-
-             # Convert to binary label
-             label = 1 if risk_score > 0.3 else 0
-             labels.append(label)
-
-         logger.info(f"Created {sum(labels)} high-risk labels out of {len(labels)} total")
-         return labels
-
-     def cross_validate(self, X: np.ndarray, y: np.ndarray,
-                        cv_folds: int = 5) -> Dict[str, List[float]]:
-         """Perform cross-validation"""
-         from sklearn.model_selection import StratifiedKFold
-
-         skf = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=42)
-
-         fold_metrics = {
-             'accuracy': [],
-             'precision': [],
-             'recall': [],
-             'f1': [],
-             'auc': []
-         }
-
-         for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
-             logger.info(f"Training fold {fold + 1}/{cv_folds}")
-
-             X_train, X_val = X[train_idx], X[val_idx]
-             y_train, y_val = y[train_idx], y[val_idx]
-
-             # Train on fold
-             self.calibration_head.train(X_train, y_train)
-
-             # Evaluate on validation set
-             val_metrics = self.evaluate(X_val, y_val)
-
-             for metric, value in val_metrics.items():
-                 fold_metrics[metric].append(value)
-
-         # Calculate mean and std
-         cv_results = {}
-         for metric, values in fold_metrics.items():
-             cv_results[f'{metric}_mean'] = np.mean(values)
-             cv_results[f'{metric}_std'] = np.std(values)
-
-         logger.info(f"Cross-validation results: {cv_results}")
-         return cv_results
 
config.yaml DELETED
@@ -1,189 +0,0 @@
- # SafeRAG Configuration File
-
- # Model Configuration
- models:
-   embedding:
-     name: "BAAI/bge-large-en-v1.5"
-     device: "cuda"
-     batch_size: 32
-
-   reranker:
-     name: "cross-encoder/ms-marco-MiniLM-L-6-v2"
-     device: "cuda"
-     batch_size: 32
-
-   generator:
-     name: "openai/gpt-oss-20b"
-     tensor_parallel_size: 1
-     gpu_memory_utilization: 0.9
-     max_tokens: 512
-     temperature: 0.7
-     top_p: 0.9
-
-   calibration:
-     type: "logistic"  # logistic, random_forest, mlp
-     input_dim: 16
-     hidden_dim: 64
-
- # Data Configuration
- data:
-   datasets:
-     - "hotpotqa"
-     - "triviaqa"
-     - "nq_open"
-
-   knowledge_base:
-     name: "wikipedia"
-     language: "en"
-     date: "20231101"
-
-   preprocessing:
-     max_sentence_length: 512
-     min_sentence_length: 20
-     cache_dir: "./cache"
-
- # Index Configuration
- index:
-   type: "ivf"  # flat, ivf
-   dimension: 1024
-   nlist: 4096
-   save_path: "./index/safrag"
-
- # Retrieval Configuration
- retrieval:
-   k: 20
-   rerank_k: 10
-   batch_size: 32
-   similarity_threshold: 0.3
-
- # Risk Calibration Configuration
- calibration:
-   tau1: 0.3  # Low risk threshold
-   tau2: 0.7  # High risk threshold
-
-   features:
-     - "num_passages"
-     - "avg_similarity"
-     - "std_similarity"
-     - "max_similarity"
-     - "min_similarity"
-     - "score_variance"
-     - "avg_token_overlap"
-     - "max_token_overlap"
-     - "avg_entity_overlap"
-     - "max_entity_overlap"
-     - "passage_consistency"
-     - "passage_consistency_std"
-     - "min_passage_similarity"
-     - "diversity"
-     - "topic_variance"
-
- # Evaluation Configuration
- evaluation:
-   metrics:
-     qa:
-       - "exact_match"
-       - "f1"
-       - "rouge1"
-       - "rouge2"
-       - "rougeL"
-
-     attribution:
-       - "precision"
-       - "recall"
-       - "f1"
-       - "citation_coverage"
-       - "citation_accuracy"
-
-     calibration:
-       - "ece"
-       - "mce"
-       - "auroc"
-       - "auprc"
-
-     system:
-       - "throughput"
-       - "latency"
-       - "gpu_utilization"
-       - "memory_usage"
-
-   test_size: 0.2
-   random_state: 42
-   cv_folds: 5
-
- # System Configuration
- system:
-   device: "cuda"
-   num_workers: 4
-   batch_size: 32
-   max_memory_gb: 16
-
-   monitoring:
-     enabled: true
-     interval: 1  # seconds
-     metrics:
-       - "cpu"
-       - "memory"
-       - "gpu"
-       - "disk"
-
- # Output Configuration
- output:
-   results_dir: "./results"
-   logs_dir: "./logs"
-   models_dir: "./models"
-   plots_dir: "./plots"
-
-   formats:
-     - "json"
-     - "csv"
-     - "html"
-
-   save_predictions: true
-   save_features: true
-   save_plots: true
-
- # Logging Configuration
- logging:
-   level: "INFO"
-   format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-   file: "./logs/safrag.log"
-   max_size: "10MB"
-   backup_count: 5
-
- # Hugging Face Configuration
- huggingface:
-   cache_dir: "./cache"
-   token: null  # Set your HF token here
-   hub_url: "https://huggingface.co"
-
-   spaces:
-     app_name: "safrag-demo"
-     hardware: "cpu"  # cpu, gpu, cpu-basic, gpu-basic
-     visibility: "public"
-
- # Experiment Configuration
- experiments:
-   baseline:
-     enabled: true
-     output_dir: "./results/baseline"
-
-   safrag:
-     enabled: true
-     output_dir: "./results/safrag"
-
-   ablation:
-     enabled: true
-     output_dir: "./results/ablation"
-
-     studies:
-       - "no_reranking"
-       - "no_calibration"
-       - "different_embeddings"
-       - "different_thresholds"
-       - "different_calibration_models"
-       - "different_retrieval_k"
-
-   comprehensive:
-     enabled: true
-     output_dir: "./results/comprehensive"
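The `tau1`/`tau2` thresholds in the calibration section split a calibrated risk score into three bands. A minimal illustration of that routing logic (the tier names and actions in the comments are assumptions for illustration, not taken from the config):

```python
def risk_tier(risk: float, tau1: float = 0.3, tau2: float = 0.7) -> str:
    """Map a calibrated risk score in [0, 1] to a tier using the config thresholds."""
    if risk <= tau1:
        return "low"     # e.g. answer directly
    if risk <= tau2:
        return "medium"  # e.g. answer with citations / hedging
    return "high"        # e.g. abstain or escalate

print(risk_tier(0.1), risk_tier(0.5), risk_tier(0.9))  # low medium high
```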
 
data_processing/__init__.py DELETED
@@ -1,4 +0,0 @@
- from .data_loader import DataLoader
- from .preprocessor import Preprocessor
-
- __all__ = ['DataLoader', 'Preprocessor']
 
data_processing/data_loader.py DELETED
@@ -1,29 +0,0 @@
- import logging
- from datasets import load_dataset
-
- logger = logging.getLogger(__name__)
-
- class DataLoader:
-     def __init__(self, cache_dir: str = "./cache"):
-         self.cache_dir = cache_dir
-
-     def load_msmarco_passage(self, split: str = "train"):
-         """Load the MS MARCO Passage Ranking dataset (v2.1) from Hugging Face"""
-         try:
-             logger.info(f"Downloading MS MARCO Passage Ranking {split} (v2.1) from Hugging Face")
-             ds = load_dataset("ms_marco", "v2.1", split=split, cache_dir=self.cache_dir)
-             return ds
-         except Exception as e:
-             logger.error(f"Failed to load MS MARCO Passage Ranking: {e}")
-             raise
-
-     def get_passage_dataset(self, split: str = "train"):
-         """Convenience wrapper around load_msmarco_passage"""
-         ds = self.load_msmarco_passage(split)
-         logger.info("MS MARCO Passage Ranking loaded successfully")
-         return ds
 
data_processing/preprocessor.py DELETED
@@ -1,112 +0,0 @@
- from typing import List, Dict, Any
- import re
- import logging
-
- logger = logging.getLogger(__name__)
-
- class Preprocessor:
-     def __init__(self):
-         """Initialize preprocessor without external dependencies"""
-         pass
-
-     def clean_text(self, text: str) -> str:
-         """Clean and normalize text"""
-         if not text:
-             return ""
-
-         # Collapse extra whitespace
-         text = text.strip()
-         text = re.sub(r'\s+', ' ', text)
-
-         # Remove special characters but keep basic punctuation
-         text = re.sub(r'[^\w\s\.\,\!\?\;\:\-\(\)]', '', text)
-
-         return text.strip()
-
-     def extract_sentences(self, text: str) -> List[str]:
-         """Extract sentences from text (simplified version without NLTK)"""
-         if not text:
-             return []
-
-         # Simple sentence splitting on terminal punctuation
-         sentences = re.split(r'[.!?]+', text)
-         sentences = [s.strip() for s in sentences if s.strip()]
-
-         return sentences
-
-     def tokenize(self, text: str) -> List[str]:
-         """Tokenize text into lowercase words (simplified version)"""
-         if not text:
-             return []
-
-         # Simple word tokenization
-         return re.findall(r'\b\w+\b', text.lower())
-
-     def preprocess_passages(self, passages: List[str]) -> List[Dict[str, Any]]:
-         """Preprocess a list of passages"""
-         processed = []
-
-         for i, passage in enumerate(passages):
-             if not passage:
-                 continue
-
-             cleaned = self.clean_text(passage)
-             sentences = self.extract_sentences(cleaned)
-             tokens = self.tokenize(cleaned)
-
-             processed.append({
-                 'id': i,
-                 'text': cleaned,
-                 'sentences': sentences,
-                 'tokens': tokens,
-                 'length': len(tokens)
-             })
-
-         return processed
-
-     def preprocess_qa_data(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
-         """Preprocess QA data, automatically converting dict/list fields to strings"""
-         processed = []
-
-         def to_str(val):
-             if isinstance(val, dict):
-                 # Concatenate all values
-                 return " ".join(to_str(v) for v in val.values())
-             elif isinstance(val, list):
-                 return " ".join(to_str(v) for v in val)
-             elif val is None:
-                 return ""
-             return str(val)
-
-         for item in data:
-             if not isinstance(item, dict):
-                 continue
-             question = to_str(item.get('question', ''))
-             answer = to_str(item.get('answer', ''))
-             context = to_str(item.get('context', ''))
-
-             processed.append({
-                 'question': self.clean_text(question),
-                 'answer': self.clean_text(answer),
-                 'context': self.clean_text(context),
-                 'question_tokens': self.tokenize(question),
-                 'answer_tokens': self.tokenize(answer),
-                 'context_tokens': self.tokenize(context)
-             })
-         return processed
-
-     def create_chunks(self, text: str, chunk_size: int = 512, overlap: int = 50) -> List[str]:
-         """Create overlapping text chunks"""
-         if not text:
-             return []
-
-         tokens = self.tokenize(text)
-         chunks = []
-
-         # Step by chunk_size - overlap so consecutive chunks share `overlap` tokens
-         for i in range(0, len(tokens), chunk_size - overlap):
-             chunks.append(' '.join(tokens[i:i + chunk_size]))
-
-         return chunks
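The window in `create_chunks` advances by `chunk_size - overlap` tokens, so each chunk shares its last `overlap` tokens with the next one. A standalone sketch of the same stride arithmetic (a re-implementation for illustration, not an import of the class above, with small hypothetical sizes):

```python
def chunk_tokens(tokens, chunk_size=8, overlap=2):
    """Slide a window of chunk_size tokens, stepping chunk_size - overlap each time."""
    step = chunk_size - overlap
    return [tokens[i:i + chunk_size] for i in range(0, len(tokens), step)]

toks = [str(i) for i in range(12)]
chunks = chunk_tokens(toks)
print(len(chunks))    # 2
print(chunks[1][:2])  # ['6', '7'] -- the tail of chunk 0 repeated
```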
 
 
eval/__init__.py DELETED
@@ -1,6 +0,0 @@
- from .eval_qa import QAEvaluator
- from .eval_attr import AttributionEvaluator
- from .eval_calib import CalibrationEvaluator
- from .eval_system import SystemEvaluator
-
- __all__ = ['QAEvaluator', 'AttributionEvaluator', 'CalibrationEvaluator', 'SystemEvaluator']
 
 
eval/eval_attr.py DELETED
@@ -1,275 +0,0 @@
- from typing import List, Dict, Any
- import re
- import numpy as np
- from sentence_transformers import SentenceTransformer
- from sklearn.metrics.pairwise import cosine_similarity
- import logging
-
- logger = logging.getLogger(__name__)
-
- class AttributionEvaluator:
-     def __init__(self, embedding_model: str = "BAAI/bge-large-en-v1.5"):
-         self.embedding_model = SentenceTransformer(embedding_model)
-
-     def evaluate_attribution(self, answers: List[str],
-                              retrieved_passages: List[List[Dict[str, Any]]],
-                              supporting_facts: List[List[str]] = None) -> Dict[str, float]:
-         """Evaluate attribution quality"""
-         if not answers or not retrieved_passages:
-             return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
-
-         precisions = []
-         recalls = []
-         f1_scores = []
-
-         for answer, passages, facts in zip(answers, retrieved_passages,
-                                            supporting_facts or [[]] * len(answers)):
-             if not passages:
-                 precisions.append(0.0)
-                 recalls.append(0.0)
-                 f1_scores.append(0.0)
-                 continue
-
-             # Extract passage texts
-             passage_texts = [p.get('text', '') for p in passages]
-
-             if facts:
-                 # Use provided supporting facts
-                 precision, recall, f1 = self._calculate_attribution_metrics(
-                     answer, passage_texts, facts)
-             else:
-                 # Use semantic similarity as a proxy
-                 precision, recall, f1 = self._calculate_semantic_attribution(
-                     answer, passage_texts)
-
-             precisions.append(precision)
-             recalls.append(recall)
-             f1_scores.append(f1)
-
-         return {
-             'precision': np.mean(precisions),
-             'recall': np.mean(recalls),
-             'f1': np.mean(f1_scores),
-             'precision_std': np.std(precisions),
-             'recall_std': np.std(recalls),
-             'f1_std': np.std(f1_scores)
-         }
-
-     def _calculate_attribution_metrics(self, answer: str, passages: List[str],
-                                        supporting_facts: List[str]) -> tuple:
-         """Calculate attribution metrics using supporting facts"""
-         if not passages:
-             return 0.0, 0.0, 0.0
-
-         # Find which passages contain supporting facts, and which facts are covered
-         relevant_passages = set()
-         covered_facts = 0
-         for fact in supporting_facts:
-             found = False
-             for i, passage in enumerate(passages):
-                 if self._passage_contains_fact(passage, fact):
-                     relevant_passages.add(i)
-                     found = True
-             if found:
-                 covered_facts += 1
-
-         # Precision: relevant retrieved passages / total retrieved passages
-         precision = len(relevant_passages) / len(passages)
-
-         # Recall: covered supporting facts / total supporting facts
-         recall = covered_facts / len(supporting_facts) if supporting_facts else 0.0
-
-         # F1 score
-         f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
-
-         return precision, recall, f1
-
-     def _calculate_semantic_attribution(self, answer: str, passages: List[str]) -> tuple:
-         """Calculate attribution using semantic similarity"""
-         if not passages:
-             return 0.0, 0.0, 0.0
-
-         # Encode answer and passages
-         answer_embedding = self.embedding_model.encode([answer])
-         passage_embeddings = self.embedding_model.encode(passages)
-
-         # Calculate similarities
-         similarities = cosine_similarity(answer_embedding, passage_embeddings)[0]
-
-         # Use a threshold to determine relevant passages
-         threshold = 0.3
-         relevant_count = np.sum(similarities >= threshold)
-
-         total_passages = len(passages)
-         precision = relevant_count / total_passages
-         recall = relevant_count / total_passages  # Simplified for the semantic proxy
-         f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
-
-         return precision, recall, f1
-
-     def _passage_contains_fact(self, passage: str, fact: str) -> bool:
-         """Check if a passage contains a supporting fact"""
-         # Simple word-overlap containment check
-         fact_words = set(fact.lower().split())
-         passage_words = set(passage.lower().split())
-
-         # Require that most fact words appear in the passage
-         overlap = len(fact_words & passage_words)
-         return overlap >= len(fact_words) * 0.7
-
-     def evaluate_citation_quality(self, answers: List[str],
-                                   citations: List[List[Dict[str, Any]]]) -> Dict[str, float]:
-         """Evaluate citation quality in answers"""
-         if not answers or not citations:
-             return {'citation_coverage': 0.0, 'citation_accuracy': 0.0}
-
-         coverage_scores = []
-         accuracy_scores = []
-
-         for answer, answer_citations in zip(answers, citations):
-             # Citation coverage: fraction of the answer that is cited
-             coverage_scores.append(self._calculate_citation_coverage(answer, answer_citations))
-
-             # Citation accuracy: fraction of citations that are relevant
-             accuracy_scores.append(self._calculate_citation_accuracy(answer, answer_citations))
-
-         return {
-             'citation_coverage': np.mean(coverage_scores),
-             'citation_accuracy': np.mean(accuracy_scores),
-             'citation_coverage_std': np.std(coverage_scores),
-             'citation_accuracy_std': np.std(accuracy_scores)
-         }
-
-     def _calculate_citation_coverage(self, answer: str, citations: List[Dict[str, Any]]) -> float:
-         """Estimate what fraction of the answer is covered by citations"""
-         if not citations:
-             return 0.0
-
-         # Simple heuristic: look for [n] citation markers in the answer
-         citation_markers = re.findall(r'\[\d+\]', answer)
-         if not citation_markers:
-             return 0.0
-
-         # Estimate coverage from citation density
-         answer_length = len(answer.split())
-         citation_density = len(citation_markers) / answer_length if answer_length > 0 else 0
-
-         return min(1.0, citation_density * 10)  # Scale factor
-
-     def _calculate_citation_accuracy(self, answer: str, citations: List[Dict[str, Any]]) -> float:
-         """Calculate accuracy of citations"""
-         if not citations:
-             return 0.0
-
-         # Simple heuristic: check whether cited passages are relevant to the answer
-         answer_words = set(answer.lower().split())
-         relevant_citations = 0
-
-         for citation in citations:
-             citation_text = citation.get('text', '')
-             citation_words = set(citation_text.lower().split())
-
-             # Check word overlap
-             overlap = len(answer_words & citation_words)
-             if overlap >= 3:  # Threshold for relevance
-                 relevant_citations += 1
-
-         return relevant_citations / len(citations)
-
-     def evaluate_retrieval_quality(self, queries: List[str],
-                                    retrieved_passages: List[List[Dict[str, Any]]],
-                                    relevant_passages: List[List[str]] = None) -> Dict[str, float]:
-         """Evaluate retrieval quality"""
-         if not queries or not retrieved_passages:
-             return {'retrieval_precision': 0.0, 'retrieval_recall': 0.0, 'retrieval_f1': 0.0}
-
-         precisions = []
-         recalls = []
-         f1_scores = []
-
-         for query, passages, relevant in zip(queries, retrieved_passages,
-                                              relevant_passages or [[]] * len(queries)):
-             if not passages:
-                 precisions.append(0.0)
-                 recalls.append(0.0)
-                 f1_scores.append(0.0)
-                 continue
-
-             if relevant:
-                 precision, recall, f1 = self._calculate_retrieval_metrics(passages, relevant)
-             else:
-                 # Use semantic similarity as a proxy
-                 precision, recall, f1 = self._calculate_semantic_retrieval(query, passages)
-
-             precisions.append(precision)
-             recalls.append(recall)
-             f1_scores.append(f1)
-
-         return {
-             'retrieval_precision': np.mean(precisions),
-             'retrieval_recall': np.mean(recalls),
-             'retrieval_f1': np.mean(f1_scores),
-             'retrieval_precision_std': np.std(precisions),
-             'retrieval_recall_std': np.std(recalls),
-             'retrieval_f1_std': np.std(f1_scores)
-         }
-
-     def _calculate_retrieval_metrics(self, passages: List[Dict[str, Any]],
-                                      relevant_passages: List[str]) -> tuple:
-         """Calculate retrieval metrics using ground truth"""
-         retrieved_texts = [p.get('text', '') for p in passages]
-
-         # Count retrieved passages that contain at least one relevant passage
-         relevant_retrieved = 0
-         for retrieved in retrieved_texts:
-             for relevant in relevant_passages:
-                 if self._passage_contains_fact(retrieved, relevant):
-                     relevant_retrieved += 1
-                     break
-
-         total_retrieved = len(passages)
-         total_relevant = len(relevant_passages)
-
-         precision = relevant_retrieved / total_retrieved if total_retrieved > 0 else 0.0
-         recall = relevant_retrieved / total_relevant if total_relevant > 0 else 0.0
-         f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
-
-         return precision, recall, f1
-
-     def _calculate_semantic_retrieval(self, query: str, passages: List[Dict[str, Any]]) -> tuple:
-         """Calculate retrieval metrics using semantic similarity"""
-         if not passages:
-             return 0.0, 0.0, 0.0
-
-         # Encode query and passages
-         query_embedding = self.embedding_model.encode([query])
-         passage_embeddings = self.embedding_model.encode([p.get('text', '') for p in passages])
-
-         # Calculate similarities
-         similarities = cosine_similarity(query_embedding, passage_embeddings)[0]
-
-         # Use a threshold to determine relevant passages
-         threshold = 0.3
-         relevant_count = np.sum(similarities >= threshold)
-
-         total_retrieved = len(passages)
-         precision = relevant_count / total_retrieved
-         recall = relevant_count / total_retrieved  # Simplified for the semantic proxy
-         f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
-
-         return precision, recall, f1
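The word-overlap heuristic in `_passage_contains_fact` drives both the fact-based attribution and retrieval metrics above. A standalone sketch of the same 70%-overlap rule (a re-implementation for illustration, with an added guard for an empty fact):

```python
def passage_contains_fact(passage: str, fact: str, min_overlap: float = 0.7) -> bool:
    """True if at least min_overlap of the fact's (lowercased) words appear in the passage."""
    fact_words = set(fact.lower().split())
    passage_words = set(passage.lower().split())
    if not fact_words:
        return False  # guard: an empty fact matches nothing
    return len(fact_words & passage_words) >= len(fact_words) * min_overlap

print(passage_contains_fact("Paris is the capital of France", "capital of France"))  # True
print(passage_contains_fact("Berlin is in Germany", "capital of France"))            # False
```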
 
eval/eval_calib.py DELETED
@@ -1,269 +0,0 @@
- from typing import List, Dict, Any
- import numpy as np
- from sklearn.metrics import roc_auc_score, average_precision_score
- import matplotlib.pyplot as plt
- import logging
-
- logger = logging.getLogger(__name__)
-
- class CalibrationEvaluator:
-     def __init__(self):
-         pass
-
-     def expected_calibration_error(self, predictions: List[float],
-                                    labels: List[int], n_bins: int = 10) -> float:
-         """Calculate Expected Calibration Error (ECE)"""
-         if not predictions or not labels:
-             return 0.0
-
-         predictions = np.array(predictions)
-         labels = np.array(labels)
-
-         # Create equal-width bins over [0, 1]
-         bin_boundaries = np.linspace(0, 1, n_bins + 1)
-         bin_lowers = bin_boundaries[:-1]
-         bin_uppers = bin_boundaries[1:]
-
-         ece = 0.0
-         for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
-             # Find predictions in this bin
-             in_bin = (predictions > bin_lower) & (predictions <= bin_upper)
-             prop_in_bin = in_bin.mean()
-
-             if prop_in_bin > 0:
-                 # Accuracy and confidence in this bin
-                 accuracy_in_bin = labels[in_bin].mean()
-                 avg_confidence_in_bin = predictions[in_bin].mean()
-
-                 # Weighted contribution to ECE
-                 ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
-
-         return ece
-
-     def maximum_calibration_error(self, predictions: List[float],
-                                   labels: List[int], n_bins: int = 10) -> float:
-         """Calculate Maximum Calibration Error (MCE)"""
-         if not predictions or not labels:
-             return 0.0
-
-         predictions = np.array(predictions)
-         labels = np.array(labels)
-
-         bin_boundaries = np.linspace(0, 1, n_bins + 1)
-         bin_lowers = bin_boundaries[:-1]
-         bin_uppers = bin_boundaries[1:]
-
-         mce = 0.0
-         for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
-             in_bin = (predictions > bin_lower) & (predictions <= bin_upper)
-
-             if in_bin.sum() > 0:
-                 accuracy_in_bin = labels[in_bin].mean()
-                 avg_confidence_in_bin = predictions[in_bin].mean()
-
-                 # MCE is the worst per-bin gap
-                 mce = max(mce, np.abs(avg_confidence_in_bin - accuracy_in_bin))
-
-         return mce
-
-     def reliability_diagram(self, predictions: List[float], labels: List[int],
-                             n_bins: int = 10, save_path: str = None) -> Dict[str, Any]:
-         """Create a reliability diagram"""
-         if not predictions or not labels:
-             return {}
-
-         predictions = np.array(predictions)
-         labels = np.array(labels)
-
-         bin_boundaries = np.linspace(0, 1, n_bins + 1)
-         bin_lowers = bin_boundaries[:-1]
-         bin_uppers = bin_boundaries[1:]
-
-         bin_centers = []
-         accuracies = []
-         confidences = []
-         counts = []
-
-         for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
-             in_bin = (predictions > bin_lower) & (predictions <= bin_upper)
-             count = in_bin.sum()
-
-             if count > 0:
-                 bin_centers.append((bin_lower + bin_upper) / 2)
-                 accuracies.append(labels[in_bin].mean())
-                 confidences.append(predictions[in_bin].mean())
-                 counts.append(count)
-
-         # Create plot
-         plt.figure(figsize=(8, 6))
-         plt.bar(bin_centers, accuracies, width=0.1, alpha=0.7, label='Accuracy')
-         plt.plot([0, 1], [0, 1], 'r--', label='Perfect Calibration')
-         plt.xlabel('Confidence')
-         plt.ylabel('Accuracy')
-         plt.title('Reliability Diagram')
-         plt.legend()
-         plt.grid(True, alpha=0.3)
-
-         if save_path:
-             plt.savefig(save_path, dpi=300, bbox_inches='tight')
-
-         plt.close()
-
-         return {
-             'bin_centers': bin_centers,
-             'accuracies': accuracies,
-             'confidences': confidences,
-             'counts': counts
-         }
-
-     def auroc(self, predictions: List[float], labels: List[int]) -> float:
-         """Calculate Area Under the ROC Curve"""
-         if not predictions or not labels:
-             return 0.0
-
-         try:
-             return roc_auc_score(labels, predictions)
-         except ValueError:  # e.g. only one class present
-             return 0.0
-
-     def auprc(self, predictions: List[float], labels: List[int]) -> float:
-         """Calculate Area Under the Precision-Recall Curve"""
-         if not predictions or not labels:
-             return 0.0
-
-         try:
-             return average_precision_score(labels, predictions)
-         except ValueError:
-             return 0.0
-
-     def risk_coverage_curve(self, predictions: List[float], labels: List[int],
-                             risk_thresholds: List[float] = None) -> Dict[str, Any]:
-         """Calculate the risk-coverage curve"""
-         if not predictions or not labels:
-             return {'thresholds': [], 'coverage': [], 'accuracy': []}
-
-         predictions = np.array(predictions)
-         labels = np.array(labels)
-
-         if risk_thresholds is None:
-             risk_thresholds = np.linspace(0, 1, 21)
-
-         coverages = []
-         accuracies = []
-
-         for threshold in risk_thresholds:
-             # Select predictions with risk <= threshold
-             selected = predictions <= threshold
-
-             if selected.sum() > 0:
-                 coverage = selected.mean()
-                 accuracy = labels[selected].mean()
-             else:
-                 coverage = 0.0
-                 accuracy = 0.0
-
-             coverages.append(coverage)
-             accuracies.append(accuracy)
-
-         return {
-             'thresholds': list(risk_thresholds),
-             'coverage': coverages,
-             'accuracy': accuracies
-         }
-
-     def evaluate_calibration(self, predictions: List[float], labels: List[int]) -> Dict[str, float]:
-         """Comprehensive calibration evaluation"""
-         if not predictions or not labels:
-             return {
-                 'ece': 0.0,
-                 'mce': 0.0,
-                 'auroc': 0.0,
-                 'auprc': 0.0
-             }
-
-         metrics = {
-             'ece': self.expected_calibration_error(predictions, labels),
-             'mce': self.maximum_calibration_error(predictions, labels),
-             'auroc': self.auroc(predictions, labels),
-             'auprc': self.auprc(predictions, labels)
-         }
-
-         # Risk-coverage analysis
-         metrics['risk_coverage'] = self.risk_coverage_curve(predictions, labels)
-
-         return metrics
-
-     def plot_calibration_curves(self, predictions: List[float], labels: List[int],
-                                 save_path: str = None) -> None:
-         """Plot calibration curves"""
-         if not predictions or not labels:
-             return
-
-         fig, axes = plt.subplots(2, 2, figsize=(12, 10))
-
-         # Reliability diagram
-         reliability_data = self.reliability_diagram(predictions, labels)
-         if reliability_data:
-             axes[0, 0].bar(reliability_data['bin_centers'], reliability_data['accuracies'],
-                            width=0.1, alpha=0.7)
-             axes[0, 0].plot([0, 1], [0, 1], 'r--')
-             axes[0, 0].set_xlabel('Confidence')
-             axes[0, 0].set_ylabel('Accuracy')
-             axes[0, 0].set_title('Reliability Diagram')
-             axes[0, 0].grid(True, alpha=0.3)
-
-         # Risk-coverage curve
-         risk_coverage = self.risk_coverage_curve(predictions, labels)
-         if risk_coverage['thresholds']:
-             axes[0, 1].plot(risk_coverage['coverage'], risk_coverage['accuracy'], 'b-')
-             axes[0, 1].set_xlabel('Coverage')
-             axes[0, 1].set_ylabel('Accuracy')
-             axes[0, 1].set_title('Risk-Coverage Curve')
-             axes[0, 1].grid(True, alpha=0.3)
-
-         # Confidence distribution
-         axes[1, 0].hist(predictions, bins=20, alpha=0.7, edgecolor='black')
-         axes[1, 0].set_xlabel('Confidence')
-         axes[1, 0].set_ylabel('Count')
-         axes[1, 0].set_title('Confidence Distribution')
-         axes[1, 0].grid(True, alpha=0.3)
-
-         # Accuracy vs Confidence
-         preds = np.array(predictions)
-         labs = np.array(labels)
-         bin_centers = np.linspace(0, 1, 11)
-         accuracies = []
-         for i in range(len(bin_centers) - 1):
-             mask = (preds >= bin_centers[i]) & (preds < bin_centers[i + 1])
-             accuracies.append(labs[mask].mean() if mask.sum() > 0 else 0)
-
-         axes[1, 1].plot(bin_centers[:-1], accuracies, 'bo-')
-         axes[1, 1].plot([0, 1], [0, 1], 'r--')
-         axes[1, 1].set_xlabel('Confidence')
-         axes[1, 1].set_ylabel('Accuracy')
-         axes[1, 1].set_title('Accuracy vs Confidence')
-         axes[1, 1].grid(True, alpha=0.3)
-
-         plt.tight_layout()
-
-         if save_path:
-             plt.savefig(save_path, dpi=300, bbox_inches='tight')
-
-         plt.close()
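The binning in `expected_calibration_error` can be checked on a toy case. A pure-Python sketch of the same computation (equal-width, left-open bins, as in the method; a re-implementation for illustration, not an import):

```python
def ece(preds, labels, n_bins=10):
    """Expected Calibration Error with equal-width, left-open bins (lo, hi]."""
    total = len(preds)
    err = 0.0
    for b in range(n_bins):
        lo, hi = b / n_bins, (b + 1) / n_bins
        idx = [i for i, p in enumerate(preds) if lo < p <= hi]
        if idx:
            acc = sum(labels[i] for i in idx) / len(idx)    # empirical accuracy in bin
            conf = sum(preds[i] for i in idx) / len(idx)    # mean confidence in bin
            err += abs(conf - acc) * len(idx) / total       # weighted per-bin gap
    return err

# Confidence 0.75 with 3/4 correct: confidence matches accuracy, so ECE is zero
print(ece([0.75, 0.75, 0.75, 0.75], [1, 1, 1, 0]))  # 0.0
```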
 
eval/eval_qa.py DELETED
@@ -1,137 +0,0 @@
1
- import re
2
- from typing import List, Dict, Any
3
- import numpy as np
4
- from evaluate import load
5
- import logging
6
-
7
- logger = logging.getLogger(__name__)
8
-
9
- class QAEvaluator:
10
- def __init__(self):
11
- self.squad_metric = load("squad")
12
- self.rouge_metric = load("rouge")
13
-
14
- def exact_match(self, predictions: List[str], references: List[str]) -> float:
15
- """Calculate exact match score"""
16
- matches = 0
17
- for pred, ref in zip(predictions, references):
18
- if self._normalize_answer(pred) == self._normalize_answer(ref):
19
- matches += 1
20
- return matches / len(predictions) if predictions else 0.0
21
-
22
- def f1_score(self, predictions: List[str], references: List[str]) -> float:
23
- """Calculate F1 score"""
24
- f1_scores = []
25
- for pred, ref in zip(predictions, references):
26
- f1 = self._calculate_f1(pred, ref)
27
- f1_scores.append(f1)
28
- return np.mean(f1_scores) if f1_scores else 0.0
29
-
30
- def rouge_score(self, predictions: List[str], references: List[str]) -> Dict[str, float]:
31
- """Calculate ROUGE scores"""
32
- if not predictions or not references:
33
- return {'rouge1': 0.0, 'rouge2': 0.0, 'rougeL': 0.0}
34
-
35
- results = self.rouge_metric.compute(
36
- predictions=predictions,
37
- references=references
38
- )
39
-
40
- return {
41
- 'rouge1': results['rouge1'],
42
- 'rouge2': results['rouge2'],
43
- 'rougeL': results['rougeL']
44
- }
45
-
46
- def squad_metrics(self, predictions: List[str], references: List[str]) -> Dict[str, float]:
47
- """Calculate SQuAD-style metrics"""
48
- if not predictions or not references:
49
- return {'exact_match': 0.0, 'f1': 0.0}
50
-
51
- # Format for SQuAD metric
52
- formatted_predictions = [{"prediction_text": pred, "id": str(i)}
53
- for i, pred in enumerate(predictions)]
54
- formatted_references = [{"answers": {"text": [ref], "answer_start": [0]}, "id": str(i)}
55
- for i, ref in enumerate(references)]
56
-
57
- results = self.squad_metric.compute(
58
- predictions=formatted_predictions,
59
- references=formatted_references
60
- )
61
-
62
- return {
63
-            'exact_match': results['exact_match'],
-            'f1': results['f1']
-        }
-
-    def evaluate_batch(self, predictions: List[str], references: List[str]) -> Dict[str, float]:
-        """Evaluate a batch of predictions"""
-        metrics = {}
-
-        # Basic metrics
-        metrics['exact_match'] = self.exact_match(predictions, references)
-        metrics['f1'] = self.f1_score(predictions, references)
-
-        # ROUGE metrics
-        rouge_scores = self.rouge_score(predictions, references)
-        metrics.update(rouge_scores)
-
-        # SQuAD metrics
-        squad_scores = self.squad_metrics(predictions, references)
-        metrics.update(squad_scores)
-
-        return metrics
-
-    def _normalize_answer(self, answer: str) -> str:
-        """Normalize answer for comparison"""
-        def remove_articles(text):
-            return re.sub(r'\b(a|an|the)\b', ' ', text)
-
-        def white_space_fix(text):
-            return ' '.join(text.split())
-
-        def remove_punc(text):
-            exclude = set(string.punctuation)
-            return ''.join(ch for ch in text if ch not in exclude)
-
-        def lower(text):
-            return text.lower()
-
-        return white_space_fix(remove_articles(remove_punc(lower(answer))))
-
-    def _calculate_f1(self, prediction: str, reference: str) -> float:
-        """Calculate F1 score between prediction and reference"""
-        pred_tokens = self._normalize_answer(prediction).split()
-        ref_tokens = self._normalize_answer(reference).split()
-
-        if len(ref_tokens) == 0:
-            return 1.0 if len(pred_tokens) == 0 else 0.0
-
-        common = set(pred_tokens) & set(ref_tokens)
-
-        if len(common) == 0:
-            return 0.0
-
-        precision = len(common) / len(pred_tokens)
-        recall = len(common) / len(ref_tokens)
-
-        f1 = 2 * precision * recall / (precision + recall)
-        return f1
-
-    def evaluate_with_context(self, predictions: List[str], references: List[str],
-                              contexts: List[str]) -> Dict[str, float]:
-        """Evaluate with context awareness"""
-        metrics = self.evaluate_batch(predictions, references)
-
-        # Context-based metrics
-        context_scores = []
-        for pred, context in zip(predictions, contexts):
-            # Check if prediction is supported by context
-            pred_words = set(pred.lower().split())
-            context_words = set(context.lower().split())
-            overlap = len(pred_words & context_words) / len(pred_words) if pred_words else 0
-            context_scores.append(overlap)
-
-        metrics['context_support'] = np.mean(context_scores)
-
-        return metrics
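Note that the set-based intersection in `_calculate_f1` ignores token multiplicity. The official SQuAD evaluation counts overlapping tokens with a multiset; a standalone sketch of that variant (function names here are illustrative, not from the deleted file):

```python
import re
import string
from collections import Counter

def normalize_answer(s: str) -> str:
    """Lowercase, strip punctuation and articles, collapse whitespace (SQuAD-style)."""
    s = s.lower()
    s = ''.join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r'\b(a|an|the)\b', ' ', s)
    return ' '.join(s.split())

def token_f1(prediction: str, reference: str) -> float:
    """Token-level F1 using multiset overlap, as in the SQuAD evaluation script."""
    pred_tokens = normalize_answer(prediction).split()
    ref_tokens = normalize_answer(reference).split()
    if not pred_tokens or not ref_tokens:
        # Both empty -> perfect match; only one empty -> no overlap possible
        return float(pred_tokens == ref_tokens)
    # Counter intersection: a repeated token counts as many times as it co-occurs
    common = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if common == 0:
        return 0.0
    precision = common / len(pred_tokens)
    recall = common / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)
```

For example, `token_f1("The cat sat", "a cat sat down")` gives precision 1.0 and recall 2/3, so F1 = 0.8.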
eval/eval_system.py DELETED
@@ -1,297 +0,0 @@
- import time
- import psutil
- import GPUtil
- from typing import List, Dict, Any, Optional
- import numpy as np
- import logging
- import threading
- from concurrent.futures import ThreadPoolExecutor, as_completed
-
- logger = logging.getLogger(__name__)
-
- class SystemEvaluator:
-     def __init__(self):
-         self.monitoring = False
-         self.metrics = []
-         self.monitor_thread = None
-
-     def start_monitoring(self):
-         """Start system monitoring"""
-         self.monitoring = True
-         self.metrics = []
-         self.monitor_thread = threading.Thread(target=self._monitor_system)
-         self.monitor_thread.start()
-         logger.info("Started system monitoring")
-
-     def stop_monitoring(self):
-         """Stop system monitoring"""
-         self.monitoring = False
-         if self.monitor_thread:
-             self.monitor_thread.join()
-         logger.info("Stopped system monitoring")
-
-     def _monitor_system(self):
-         """Monitor system resources"""
-         while self.monitoring:
-             try:
-                 # CPU usage (cpu_percent itself blocks for the 1s interval)
-                 cpu_percent = psutil.cpu_percent(interval=1)
-
-                 # Memory usage
-                 memory = psutil.virtual_memory()
-                 memory_percent = memory.percent
-                 memory_used_gb = memory.used / (1024**3)
-
-                 # GPU usage (if available)
-                 gpu_metrics = self._get_gpu_metrics()
-
-                 # Disk usage
-                 disk = psutil.disk_usage('/')
-                 disk_percent = disk.percent
-
-                 metric = {
-                     'timestamp': time.time(),
-                     'cpu_percent': cpu_percent,
-                     'memory_percent': memory_percent,
-                     'memory_used_gb': memory_used_gb,
-                     'disk_percent': disk_percent,
-                     **gpu_metrics
-                 }
-
-                 self.metrics.append(metric)
-
-             except Exception as e:
-                 logger.error(f"Error monitoring system: {e}")
-
-             time.sleep(1)  # Throttle the sampling loop
-
-     def _get_gpu_metrics(self) -> Dict[str, Any]:
-         """Get GPU metrics"""
-         try:
-             gpus = GPUtil.getGPUs()
-             if gpus:
-                 gpu = gpus[0]  # Use first GPU
-                 return {
-                     'gpu_utilization': gpu.load * 100,
-                     'gpu_memory_used': gpu.memoryUsed,
-                     'gpu_memory_total': gpu.memoryTotal,
-                     'gpu_memory_percent': (gpu.memoryUsed / gpu.memoryTotal) * 100,
-                     'gpu_temperature': gpu.temperature
-                 }
-         except Exception:
-             pass
-
-         return {
-             'gpu_utilization': 0,
-             'gpu_memory_used': 0,
-             'gpu_memory_total': 0,
-             'gpu_memory_percent': 0,
-             'gpu_temperature': 0
-         }
-
-     def measure_throughput(self, func, args_list: List[tuple],
-                            max_workers: int = 4) -> Dict[str, Any]:
-         """Measure throughput of a function"""
-
-         start_time = time.time()
-
-         # Execute the function concurrently across a thread pool
-         results = []
-         with ThreadPoolExecutor(max_workers=max_workers) as executor:
-             futures = [executor.submit(func, *args) for args in args_list]
-
-             for future in as_completed(futures):
-                 try:
-                     result = future.result()
-                     results.append(result)
-                 except Exception as e:
-                     logger.error(f"Error in throughput measurement: {e}")
-
-         end_time = time.time()
-
-         total_time = end_time - start_time
-         throughput = len(results) / total_time  # queries per second
-
-         return {
-             'total_queries': len(args_list),
-             'successful_queries': len(results),
-             'total_time': total_time,
-             'throughput_qps': throughput,
-             'avg_time_per_query': total_time / len(args_list) if args_list else 0
-         }
-
-     def measure_latency(self, func, args: tuple, num_runs: int = 10) -> Dict[str, Any]:
-         """Measure latency of a function"""
-
-         latencies = []
-
-         for _ in range(num_runs):
-             start_time = time.time()
-             try:
-                 result = func(*args)
-                 end_time = time.time()
-                 latency = end_time - start_time
-                 latencies.append(latency)
-             except Exception as e:
-                 logger.error(f"Error in latency measurement: {e}")
-                 latencies.append(float('inf'))
-
-         # Drop infinite latencies (failed runs)
-         latencies = [l for l in latencies if l != float('inf')]
-
-         if not latencies:
-             return {
-                 'avg_latency': 0,
-                 'p50_latency': 0,
-                 'p95_latency': 0,
-                 'p99_latency': 0,
-                 'min_latency': 0,
-                 'max_latency': 0,
-                 'std_latency': 0
-             }
-
-         latencies = np.array(latencies)
-
-         return {
-             'avg_latency': np.mean(latencies),
-             'p50_latency': np.percentile(latencies, 50),
-             'p95_latency': np.percentile(latencies, 95),
-             'p99_latency': np.percentile(latencies, 99),
-             'min_latency': np.min(latencies),
-             'max_latency': np.max(latencies),
-             'std_latency': np.std(latencies)
-         }
-
-     def measure_batch_latency(self, func, args_list: List[tuple],
-                               batch_sizes: List[int] = [1, 4, 8, 16]) -> Dict[str, Any]:
-         """Measure latency for different batch sizes"""
-
-         results = {}
-
-         for batch_size in batch_sizes:
-             batch_latencies = []
-
-             # Process in batches
-             for i in range(0, len(args_list), batch_size):
-                 batch_args = args_list[i:i + batch_size]
-
-                 start_time = time.time()
-                 try:
-                     batch_results = [func(*args) for args in batch_args]
-                     end_time = time.time()
-
-                     batch_latency = end_time - start_time
-                     batch_latencies.append(batch_latency)
-
-                 except Exception as e:
-                     logger.error(f"Error in batch latency measurement: {e}")
-
-             if batch_latencies:
-                 results[f'batch_size_{batch_size}'] = {
-                     'avg_latency': np.mean(batch_latencies),
-                     'p95_latency': np.percentile(batch_latencies, 95),
-                     'throughput': batch_size / np.mean(batch_latencies)
-                 }
-
-         return results
-
-     def get_system_stats(self) -> Dict[str, Any]:
-         """Get current system statistics"""
-
-         if not self.metrics:
-             return {}
-
-         # Calculate statistics from monitoring data
-         cpu_values = [m['cpu_percent'] for m in self.metrics]
-         memory_values = [m['memory_percent'] for m in self.metrics]
-         gpu_values = [m.get('gpu_utilization', 0) for m in self.metrics]
-
-         return {
-             'monitoring_duration': len(self.metrics),
-             'cpu': {
-                 'avg': np.mean(cpu_values),
-                 'max': np.max(cpu_values),
-                 'min': np.min(cpu_values),
-                 'std': np.std(cpu_values)
-             },
-             'memory': {
-                 'avg': np.mean(memory_values),
-                 'max': np.max(memory_values),
-                 'min': np.min(memory_values),
-                 'std': np.std(memory_values)
-             },
-             'gpu': {
-                 'avg': np.mean(gpu_values),
-                 'max': np.max(gpu_values),
-                 'min': np.min(gpu_values),
-                 'std': np.std(gpu_values)
-             }
-         }
-
-     def evaluate_retrieval_performance(self, retriever, queries: List[str],
-                                        k: int = 10) -> Dict[str, Any]:
-         """Evaluate retrieval performance"""
-
-         # Measure latency
-         latency_stats = self.measure_latency(
-             retriever.retrieve_single,
-             (queries[0], k),
-             num_runs=5
-         )
-
-         # Measure throughput
-         throughput_stats = self.measure_throughput(
-             retriever.retrieve_single,
-             [(query, k) for query in queries[:10]],  # Limit for throughput test
-             max_workers=4
-         )
-
-         return {
-             'latency': latency_stats,
-             'throughput': throughput_stats
-         }
-
-     def evaluate_generation_performance(self, generator, questions: List[str],
-                                         passages_list: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
-         """Evaluate generation performance"""
-
-         # Measure latency
-         latency_stats = self.measure_latency(
-             generator.generate_with_strategy,
-             (questions[0], passages_list[0]),
-             num_runs=5
-         )
-
-         # Measure throughput
-         throughput_stats = self.measure_throughput(
-             generator.generate_with_strategy,
-             list(zip(questions[:5], passages_list[:5])),  # Limit for throughput test
-             max_workers=2
-         )
-
-         return {
-             'latency': latency_stats,
-             'throughput': throughput_stats
-         }
-
-     def evaluate_end_to_end_performance(self, rag_system, queries: List[str]) -> Dict[str, Any]:
-         """Evaluate end-to-end RAG performance"""
-
-         # Measure latency
-         latency_stats = self.measure_latency(
-             rag_system.query,
-             (queries[0],),
-             num_runs=5
-         )
-
-         # Measure throughput
-         throughput_stats = self.measure_throughput(
-             rag_system.query,
-             [(query,) for query in queries[:10]],  # Limit for throughput test
-             max_workers=2
-         )
-
-         return {
-             'latency': latency_stats,
-             'throughput': throughput_stats
-         }
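The percentile summary that `measure_latency` builds with numpy can also be produced with the stdlib alone; a minimal sketch of the same p50/p95/p99 summary (the timing loop and key names mirror the deleted file, the `statistics`-based implementation is an assumption of this sketch):

```python
import time
import statistics

def measure_latency(func, args=(), num_runs=10):
    """Time repeated calls and summarize latency percentiles (seconds)."""
    latencies = []
    for _ in range(num_runs):
        start = time.perf_counter()  # monotonic, higher resolution than time.time()
        func(*args)
        latencies.append(time.perf_counter() - start)
    latencies.sort()
    # quantiles(n=100) yields 99 cut points; index k-1 approximates the k-th percentile
    q = statistics.quantiles(latencies, n=100)
    return {
        'avg_latency': statistics.fmean(latencies),
        'p50_latency': q[49],
        'p95_latency': q[94],
        'p99_latency': q[98],
        'min_latency': latencies[0],
        'max_latency': latencies[-1],
    }
```

`time.perf_counter()` is preferable to `time.time()` for interval timing because it is monotonic and unaffected by system clock adjustments.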
exp_pipeline/pipeline.py DELETED
@@ -1,56 +0,0 @@
- """
- End-to-end pipeline for dataset download, preprocessing, embedding, and indexing.
- """
- import logging
- from data_processing.data_loader import DataLoader
- from data_processing.preprocessor import Preprocessor
- from retriever.embedder import Embedder
- from retriever.faiss_index import build_faiss_index
-
- logger = logging.getLogger(__name__)
-
-
- def run_pipeline(split: str = "train"):
-     # 1. Download the MS MARCO Passage Ranking dataset
-     data_loader = DataLoader()
-     raw_data = data_loader.get_passage_dataset(split)
-     logger.info(f"Loaded {len(raw_data)} samples from MS MARCO Passage Ranking [{split}]")
-     print("data_loader\n")
-
-     # 2. Preprocess the data
-     preprocessor = Preprocessor()
-     # Convert a HuggingFace datasets object (dict of columns) into a list of row dicts
-     if hasattr(raw_data, "to_dict"):
-         raw_data = raw_data.to_dict()
-         raw_data = [dict(zip(raw_data.keys(), v)) for v in zip(*raw_data.values())]
-     print("raw_data\n")
-
-     # MS MARCO Passage v2.1: use the passages["passage_text"] field
-     passages = []
-     for item in raw_data:
-         if "passages" in item and "passage_text" in item["passages"]:
-             passages.extend(item["passages"]["passage_text"])
-     processed = preprocessor.preprocess_passages(passages)
-     texts = [p["text"] for p in processed]
-     print("texts\n")
-
-     logger.info(f"Processed {len(texts)} passages")
-
-     # 3. Generate embeddings
-     embedder = Embedder(device="cuda")
-     embeddings = embedder.encode(texts)
-     print(f"Embedding shape: {getattr(embeddings, 'shape', None)}")
-     print(f"Texts count: {len(texts)}")
-     if embeddings is None or not hasattr(embeddings, 'shape') or len(embeddings.shape) != 2 or embeddings.shape[0] == 0:
-         raise ValueError("Embeddings is empty or not a 2D array. Check input texts and embedding model.")
-
-     # 4. Build the FAISS index
-     index = build_faiss_index(embeddings, texts, index_type="HNSW")
-     logger.info("FAISS index built successfully")
-     # Persist the index to the ../index directory
-     index.save("../index/msmarco_hnsw")
-     logger.info("FAISS index saved to ../index/msmarco_hnsw")
-     return index
-
- if __name__ == "__main__":
-     run_pipeline("train")
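The `zip(*raw_data.values())` expression in `run_pipeline` transposes a columnar dict, as returned by `Dataset.to_dict()`, into a list of row dicts. Isolated, the trick looks like this (the sample data is made up for illustration):

```python
def columns_to_rows(columns: dict) -> list:
    """Turn {'col': [v1, v2, ...], ...} into [{'col': v1, ...}, {'col': v2, ...}, ...]."""
    # zip(*values) pairs up the i-th element of every column; dict(zip(keys, row))
    # reattaches the column names to each transposed row
    return [dict(zip(columns.keys(), row)) for row in zip(*columns.values())]

table = {'query': ['q1', 'q2'], 'answer': ['a1', 'a2']}
rows = columns_to_rows(table)
# rows == [{'query': 'q1', 'answer': 'a1'}, {'query': 'q2', 'answer': 'a2'}]
```

Note that `zip` stops at the shortest column, so ragged columns are silently truncated.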
generator/__init__.py DELETED
@@ -1,5 +0,0 @@
- from .vllm_server import VLLMServer
- from .safe_generate import SafeGenerator
- from .prompt_templates import PromptTemplates
-
- __all__ = ['VLLMServer', 'SafeGenerator', 'PromptTemplates']
generator/prompt_templates.py DELETED
@@ -1,113 +0,0 @@
- from typing import List, Dict, Any
- from dataclasses import dataclass
-
- @dataclass
- class PromptTemplate:
-     name: str
-     template: str
-     system_prompt: str = ""
-
- class PromptTemplates:
-     def __init__(self):
-         self.templates = {
-             'rag': PromptTemplate(
-                 name='rag',
-                 system_prompt="You are a helpful assistant that answers questions based on provided context. Always cite your sources when possible.",
-                 template="""Context:
- {context}
-
- Question: {question}
-
- Answer:"""
-             ),
-
-             'rag_with_citations': PromptTemplate(
-                 name='rag_with_citations',
-                 system_prompt="You are a helpful assistant that answers questions based on provided context. Always provide citations in the format [1], [2], etc.",
-                 template="""Context:
- {context}
-
- Question: {question}
-
- Answer (with citations):"""
-             ),
-
-             'rag_safe': PromptTemplate(
-                 name='rag_safe',
-                 system_prompt="You are a helpful assistant that answers questions based on provided context. If you're uncertain, say so. Always cite your sources.",
-                 template="""Context:
- {context}
-
- Question: {question}
-
- Instructions:
- - Answer based on the provided context
- - If uncertain, express your uncertainty
- - Always provide citations
- - If the context doesn't contain enough information, say so
-
- Answer:"""
-             ),
-
-             'rag_uncertain': PromptTemplate(
-                 name='rag_uncertain',
-                 system_prompt="You are a helpful assistant. Express uncertainty when appropriate and always cite sources.",
-                 template="""Context:
- {context}
-
- Question: {question}
-
- Answer (express uncertainty if appropriate):"""
-             )
-         }
-
-     def get_template(self, name: str) -> PromptTemplate:
-         """Get a prompt template by name"""
-         if name not in self.templates:
-             raise ValueError(f"Unknown template: {name}")
-         return self.templates[name]
-
-     def format_prompt(self, template_name: str, **kwargs) -> str:
-         """Format a prompt using a template"""
-         template = self.get_template(template_name)
-
-         # Format the main template
-         formatted = template.template.format(**kwargs)
-
-         # Add system prompt if available
-         if template.system_prompt:
-             formatted = f"{template.system_prompt}\n\n{formatted}"
-
-         return formatted
-
-     def format_context(self, retrieved_passages: List[Dict[str, Any]],
-                        max_length: int = 2000) -> str:
-         """Format retrieved passages as context"""
-         context_parts = []
-         current_length = 0
-
-         for i, passage in enumerate(retrieved_passages):
-             text = passage.get('text', '')
-             if current_length + len(text) > max_length:
-                 break
-
-             context_parts.append(f"[{i+1}] {text}")
-             current_length += len(text)
-
-         return "\n\n".join(context_parts)
-
-     def create_rag_prompt(self, question: str, retrieved_passages: List[Dict[str, Any]],
-                           template_name: str = 'rag', max_context_length: int = 2000) -> str:
-         """Create a RAG prompt"""
-         context = self.format_context(retrieved_passages, max_context_length)
-         return self.format_prompt(template_name, question=question, context=context)
-
-     def create_batch_prompts(self, questions: List[str],
-                              retrieved_passages_list: List[List[Dict[str, Any]]],
-                              template_name: str = 'rag') -> List[str]:
-         """Create multiple RAG prompts"""
-         prompts = []
-         for question, passages in zip(questions, retrieved_passages_list):
-             prompt = self.create_rag_prompt(question, passages, template_name)
-             prompts.append(prompt)
-         return prompts
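`format_context` above implements a simple greedy character budget: passages are numbered `[1]`, `[2]`, ... and appended until the next one would exceed the limit. The same behavior as a standalone function (names mirror the deleted method):

```python
def format_context(passages, max_length=2000):
    """Number passages [1], [2], ... stopping before max_length characters of passage text."""
    parts, used = [], 0
    for i, passage in enumerate(passages):
        text = passage.get('text', '')
        # Greedy cutoff: stop at the first passage that would blow the budget
        if used + len(text) > max_length:
            break
        parts.append(f"[{i+1}] {text}")
        used += len(text)
    return "\n\n".join(parts)
```

One design note: the budget counts only the passage text, not the `[i] ` markers or the `\n\n` separators, so the final string can be slightly longer than `max_length`.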
generator/safe_generate.py DELETED
@@ -1,170 +0,0 @@
- import re
- import logging
- from typing import List, Dict, Any, Optional, Tuple
- from .vllm_server import VLLMServer
- from .prompt_templates import PromptTemplates
- from ..calibration.features import RiskFeatureExtractor
-
- logger = logging.getLogger(__name__)
-
- class SafeGenerator:
-     def __init__(self, vllm_server: VLLMServer,
-                  risk_extractor: RiskFeatureExtractor,
-                  tau1: float = 0.3, tau2: float = 0.7):
-         self.vllm_server = vllm_server
-         self.risk_extractor = risk_extractor
-         self.prompt_templates = PromptTemplates()
-         self.tau1 = tau1  # Low risk threshold
-         self.tau2 = tau2  # High risk threshold
-
-     def generate_with_strategy(self, question: str,
-                                retrieved_passages: List[Dict[str, Any]],
-                                force_citation: bool = False) -> Dict[str, Any]:
-         """Generate answer with adaptive strategy based on risk assessment"""
-
-         # Extract risk features
-         risk_features = self.risk_extractor.extract_features(
-             question, retrieved_passages
-         )
-
-         # Get risk score (placeholder - will be implemented in calibration module)
-         risk_score = self._estimate_risk_score(risk_features)
-
-         # Determine strategy based on risk score
-         if risk_score < self.tau1:
-             # Low risk: normal generation
-             strategy = "normal"
-             temperature = 0.7
-             template_name = "rag"
-         elif risk_score < self.tau2:
-             # Medium risk: conservative generation with citations
-             strategy = "conservative"
-             temperature = 0.5
-             template_name = "rag_with_citations"
-             force_citation = True
-         else:
-             # High risk: very conservative or refuse
-             strategy = "conservative_or_refuse"
-             temperature = 0.3
-             template_name = "rag_safe"
-             force_citation = True
-
-         # Generate prompt
-         prompt = self.prompt_templates.create_rag_prompt(
-             question, retrieved_passages, template_name
-         )
-
-         # Generate answer
-         try:
-             result = self.vllm_server.generate_single(
-                 prompt,
-                 max_tokens=512,
-                 temperature=temperature
-             )
-
-             # Post-process for citations if needed
-             if force_citation:
-                 result = self._add_citations(result, retrieved_passages)
-
-             return {
-                 'answer': result,
-                 'risk_score': risk_score,
-                 'strategy': strategy,
-                 'temperature': temperature,
-                 'features': risk_features,
-                 'citations': self._extract_citations(result, retrieved_passages)
-             }
-
-         except Exception as e:
-             logger.error(f"Generation failed: {e}")
-             return {
-                 'answer': "I apologize, but I encountered an error while generating a response.",
-                 'risk_score': 1.0,
-                 'strategy': 'error',
-                 'temperature': 0.0,
-                 'features': risk_features,
-                 'citations': []
-             }
-
-     def generate_batch(self, questions: List[str],
-                        retrieved_passages_list: List[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
-         """Generate answers for multiple questions"""
-         results = []
-
-         for question, passages in zip(questions, retrieved_passages_list):
-             result = self.generate_with_strategy(question, passages)
-             results.append(result)
-
-         return results
-
-     def _estimate_risk_score(self, features: Dict[str, Any]) -> float:
-         """Estimate risk score from features (placeholder implementation)"""
-         # This is a simplified risk estimation
-         # In practice, this would use a trained calibration model
-
-         # Higher similarity scores = lower risk
-         avg_similarity = features.get('avg_similarity', 0.5)
-
-         # More diverse passages = lower risk
-         diversity = features.get('diversity', 0.5)
-
-         # More passages = lower risk (up to a point)
-         num_passages = min(features.get('num_passages', 1), 10)
-         passage_score = 1.0 - (num_passages / 10.0)
-
-         # Combine factors
-         risk_score = 1.0 - (avg_similarity * 0.4 + diversity * 0.3 + (1.0 - passage_score) * 0.3)
-
-         return max(0.0, min(1.0, risk_score))
-
-     def _add_citations(self, answer: str, passages: List[Dict[str, Any]]) -> str:
-         """Add citations to answer if not present"""
-         if '[' in answer and ']' in answer:
-             return answer  # Already has citations
-
-         # Simple citation addition (in practice, use more sophisticated methods)
-         cited_answer = answer
-         for i, passage in enumerate(passages[:3]):  # Limit to first 3 passages
-             if any(word in answer.lower() for word in passage['text'].lower().split()[:5]):
-                 cited_answer += f" [{i+1}]"
-
-         return cited_answer
-
-     def _extract_citations(self, answer: str, passages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
-         """Extract citations from answer"""
-         citations = []
-
-         # Find citation markers like [1], [2], etc.
-         citation_matches = re.findall(r'\[(\d+)\]', answer)
-
-         for match in citation_matches:
-             idx = int(match) - 1
-             if 0 <= idx < len(passages):
-                 citations.append({
-                     'id': idx,
-                     'text': passages[idx]['text'],
-                     'metadata': passages[idx].get('metadata', {})
-                 })
-
-         return citations
-
-     def get_generation_stats(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
-         """Get statistics from generation results"""
-         if not results:
-             return {}
-
-         risk_scores = [r['risk_score'] for r in results]
-         strategies = [r['strategy'] for r in results]
-
-         strategy_counts = {}
-         for strategy in strategies:
-             strategy_counts[strategy] = strategy_counts.get(strategy, 0) + 1
-
-         return {
-             'num_queries': len(results),
-             'avg_risk_score': sum(risk_scores) / len(risk_scores),
-             'min_risk_score': min(risk_scores),
-             'max_risk_score': max(risk_scores),
-             'strategy_distribution': strategy_counts,
-             'avg_citations_per_answer': sum(len(r.get('citations', [])) for r in results) / len(results)
-         }
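The tau1/tau2 branching in `generate_with_strategy` is a three-way threshold policy on the risk score. Stripped to its essentials (threshold defaults, temperatures, and template names mirror the deleted file):

```python
def pick_strategy(risk_score: float, tau1: float = 0.3, tau2: float = 0.7) -> dict:
    """Map a risk score in [0, 1] to a generation strategy, temperature, and template."""
    if risk_score < tau1:
        # Low risk: normal generation
        return {'strategy': 'normal', 'temperature': 0.7, 'template': 'rag'}
    if risk_score < tau2:
        # Medium risk: conservative generation with forced citations
        return {'strategy': 'conservative', 'temperature': 0.5,
                'template': 'rag_with_citations'}
    # High risk: very conservative, may refuse
    return {'strategy': 'conservative_or_refuse', 'temperature': 0.3,
            'template': 'rag_safe'}
```

Both comparisons are strict, so a score exactly equal to a threshold falls into the more conservative bucket.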
generator/vllm_server.py DELETED
@@ -1,102 +0,0 @@
- import asyncio
- import functools
- import logging
- from concurrent.futures import ThreadPoolExecutor
- from typing import List, Dict, Any, Optional
- from vllm import LLM, SamplingParams
-
- logger = logging.getLogger(__name__)
-
- class VLLMServer:
-     def __init__(self, model_name: str = "openai/gpt-oss-20b",
-                  tensor_parallel_size: int = 1, gpu_memory_utilization: float = 0.9):
-         self.model_name = model_name
-         self.tensor_parallel_size = tensor_parallel_size
-         self.gpu_memory_utilization = gpu_memory_utilization
-         self.llm = None
-         self.executor = ThreadPoolExecutor(max_workers=4)
-
-     def initialize(self):
-         """Initialize the vLLM model"""
-         try:
-             self.llm = LLM(
-                 model=self.model_name,
-                 tensor_parallel_size=self.tensor_parallel_size,
-                 gpu_memory_utilization=self.gpu_memory_utilization,
-                 trust_remote_code=True
-             )
-             logger.info(f"Initialized vLLM with model: {self.model_name}")
-         except Exception as e:
-             logger.error(f"Failed to initialize vLLM: {e}")
-             raise
-
-     def generate(self, prompts: List[str],
-                  max_tokens: int = 512,
-                  temperature: float = 0.7,
-                  top_p: float = 0.9,
-                  stop: Optional[List[str]] = None) -> List[Dict[str, Any]]:
-         """Generate text for prompts"""
-         if self.llm is None:
-             self.initialize()
-
-         sampling_params = SamplingParams(
-             max_tokens=max_tokens,
-             temperature=temperature,
-             top_p=top_p,
-             stop=stop
-         )
-
-         try:
-             outputs = self.llm.generate(prompts, sampling_params)
-
-             results = []
-             for output in outputs:
-                 results.append({
-                     'text': output.outputs[0].text,
-                     'prompt': output.prompt,
-                     'finish_reason': output.outputs[0].finish_reason,
-                     'token_ids': output.outputs[0].token_ids,
-                     'logprobs': getattr(output.outputs[0], 'logprobs', None)
-                 })
-
-             return results
-         except Exception as e:
-             logger.error(f"Generation failed: {e}")
-             raise
-
-     def generate_single(self, prompt: str, **kwargs) -> str:
-         """Generate text for a single prompt"""
-         results = self.generate([prompt], **kwargs)
-         return results[0]['text'] if results else ""
-
-     def generate_batch(self, prompts: List[str], batch_size: int = 8, **kwargs) -> List[str]:
-         """Generate text for multiple prompts in batches"""
-         all_results = []
-
-         for i in range(0, len(prompts), batch_size):
-             batch_prompts = prompts[i:i + batch_size]
-             batch_results = self.generate(batch_prompts, **kwargs)
-             all_results.extend([r['text'] for r in batch_results])
-
-         return all_results
-
-     async def generate_async(self, prompts: List[str], **kwargs) -> List[Dict[str, Any]]:
-         """Async generation (run_in_executor does not forward kwargs, so bind them first)"""
-         loop = asyncio.get_event_loop()
-         call = functools.partial(self.generate, prompts, **kwargs)
-         return await loop.run_in_executor(self.executor, call)
-
-     def get_model_info(self) -> Dict[str, Any]:
-         """Get model information"""
-         if self.llm is None:
-             return {}
-
-         return {
-             'model_name': self.model_name,
-             'tensor_parallel_size': self.tensor_parallel_size,
-             'gpu_memory_utilization': self.gpu_memory_utilization,
-             'is_initialized': self.llm is not None
-         }
-
-     def cleanup(self):
-         """Cleanup resources"""
-         if self.executor:
-             self.executor.shutdown(wait=True)
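`generate_batch` above chunks the prompt list with a fixed stride before handing each slice to the model. The same chunking in isolation, with a generator instead of the inline loop:

```python
def chunked(items, batch_size):
    """Yield successive slices of at most batch_size items, preserving order."""
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size]

prompts = [f"prompt-{i}" for i in range(5)]
batches = list(chunked(prompts, 2))
# Three batches of sizes 2, 2, 1; the last batch is simply shorter
```

The final slice may be smaller than `batch_size`; the slicing never raises on overshoot, so no special-casing of the remainder is needed.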
real_embedding_test.py DELETED
@@ -1,269 +0,0 @@
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- SafeRAG Real Embedding Test
- Load data -> Generate real embeddings using sentence-transformers -> Build index -> Retrieve
- """
-
- import sys
- import os
- import time
- import numpy as np
- sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
- def test_real_embedding_pipeline():
-     """Test the complete pipeline with real embeddings"""
-     print("SafeRAG Real Embedding Pipeline Test")
-     print("=" * 50)
-
-     try:
-         # Step 1: Load data
-         print("\n1. Loading data...")
-         from data_processing import DataLoader, Preprocessor
-
-         loader = DataLoader()
-         preprocessor = Preprocessor()
-
-         # Load knowledge base
-         kb_passages = loader.get_knowledge_base()
-         print(f"   ✓ Loaded {len(kb_passages)} knowledge base passages")
-
-         # Show sample passages
-         for i, passage in enumerate(kb_passages):
-             print(f"   [{i+1}] {passage}")
-
-         # Preprocess passages
-         processed_passages = preprocessor.preprocess_passages(kb_passages)
-         print(f"   ✓ Preprocessed {len(processed_passages)} passages")
-
-         # Step 2: Generate real embeddings
-         print("\n2. Generating real embeddings with sentence-transformers...")
-         from retriever import Embedder
-
-         # Use a smaller model for faster testing
-         embedder = Embedder(model_name="all-MiniLM-L6-v2", device="cpu")
-         print(f"   ✓ Loaded embedding model: {embedder.model_name}")
-         print(f"   ✓ Embedding dimension: {embedder.get_dimension()}")
-
-         # Extract text from processed passages
-         passage_texts = [p['text'] for p in processed_passages]
-
-         # Generate embeddings
-         start_time = time.time()
-         embeddings = embedder.encode_passages(passage_texts)
-         embedding_time = time.time() - start_time
-
-         print(f"   ✓ Generated {embeddings.shape[0]} embeddings in {embedding_time:.3f}s")
-         print(f"   ✓ Embedding shape: {embeddings.shape}")
-         print(f"   ✓ Embedding type: {type(embeddings)}")
-
-         # Show embedding statistics
-         print(f"   ✓ Embedding stats:")
-         print(f"     - Mean: {np.mean(embeddings):.4f}")
-         print(f"     - Std: {np.std(embeddings):.4f}")
-         print(f"     - Min: {np.min(embeddings):.4f}")
-         print(f"     - Max: {np.max(embeddings):.4f}")
-
-         # Step 3: Build FAISS index
-         print("\n3. Building FAISS index...")
-         from retriever import FAISSIndex
-
-         index = FAISSIndex(embedder.get_dimension())
-         start_time = time.time()
-         index.build_index(embeddings, passage_texts)
-         build_time = time.time() - start_time
-
-         print(f"   ✓ Built FAISS index in {build_time:.3f}s")
-         print(f"   ✓ Index contains {index.index.ntotal} vectors")
-
-         # Step 4: Test retrieval
-         print("\n4. Testing retrieval...")
-         from retriever import Retriever
-
-         retriever = Retriever(embedder, index, None)  # No reranker for simplicity
-
-         test_queries = [
-             "What is machine learning?",
-             "Tell me about the capital of France",
-             "How does Python work?",
-             "What is artificial intelligence?"
-         ]
-
-         for query in test_queries:
-             print(f"\n   Query: '{query}'")
-             start_time = time.time()
-             results = retriever.retrieve_single(query, k=3)
-             retrieval_time = time.time() - start_time
-
-             print(f"   ✓ Retrieved {len(results)} passages in {retrieval_time:.3f}s")
-             for i, result in enumerate(results):
-                 print(f"     [{i+1}] Score: {result['score']:.4f}")
-                 print(f"         Text: {result['text'][:100]}...")
-
-         # Step 5: Test similarity calculation
-         print("\n5. Testing similarity calculation...")
-
-         # Test query-passage similarity
-         query = "What is machine learning?"
-         query_embedding = embedder.encode_queries([query])[0]
-
-         print(f"   Query: '{query}'")
-         print(f"   Query embedding shape: {query_embedding.shape}")
-
-         # Calculate similarities with all passages
-         similarities = []
-         for i, passage_embedding in enumerate(embeddings):
-             # Cosine similarity
-             similarity = np.dot(query_embedding, passage_embedding) / (
-                 np.linalg.norm(query_embedding) * np.linalg.norm(passage_embedding)
-             )
-             similarities.append((i, similarity, passage_texts[i]))
-
-         # Sort by similarity
-         similarities.sort(key=lambda x: x[1], reverse=True)
-
-         print(f"   ✓ Calculated similarities with {len(similarities)} passages")
-         print(f"   Top 3 most similar passages:")
-         for i, (idx, sim, text) in enumerate(similarities[:3]):
-             print(f"     [{i+1}] Similarity: {sim:.4f}")
-             print(f"         Text: {text[:80]}...")
-
-         # Step 6: Test generation
-         print("\n6. Testing generation...")
-         from generator import SafeGenerator, PromptTemplates
-
-         templates = PromptTemplates()
-         generator = SafeGenerator(None, None, 0.3, 0.7)  # Simplified version
-
-         test_query = "What is machine learning?"
-         retrieved_passages = retriever.retrieve_single(test_query, k=3)
-
-         print(f"   Query: '{test_query}'")
-         print(f"   Retrieved {len(retrieved_passages)} passages")
-
-         # Generate answer
-         start_time = time.time()
-         result = generator.generate_with_strategy(test_query, retrieved_passages)
-         generation_time = time.time() - start_time
-
-         print(f"   ✓ Generated answer in {generation_time:.3f}s")
-         print(f"   Answer: {result['answer'][:200]}...")
-         print(f"   Risk Score: {result['risk_score']:.3f}")
-         print(f"   Strategy: {result['strategy']}")
-
-         print("\n" + "=" * 50)
-         print("🎉 Real embedding pipeline test completed successfully!")
-         print("\nPipeline Summary:")
-         print(f"- Data Loading: {len(kb_passages)} passages")
-         print(f"- Real Embedding Generation: {embeddings.shape[0]} vectors ({embeddings.shape[1]}D)")
-         print(f"- Index Building: {index.index.ntotal} indexed vectors")
-         print(f"- Retrieval: {len(test_queries)} test queries")
-         print(f"- Similarity Calculation: Cosine similarity with all passages")
-         print(f"- Generation: Risk-aware answer generation")
-
-         return True
-
-     except Exception as e:
-         print(f"\n❌ Pipeline test failed: {e}")
-         import traceback
-         traceback.print_exc()
-         return False
-
- def test_embedding_quality():
-     """Test embedding quality and properties"""
-     print("\n" + "=" * 50)
-     print("Testing Embedding Quality")
-     print("=" * 50)
-
-     try:
-         from retriever import Embedder
-
-         # Initialize embedder
-         embedder = Embedder(model_name="all-MiniLM-L6-v2", device="cpu")
-
-         # Test texts
-         test_texts = [
-             "Machine learning is a subset of artificial intelligence",
-             "The capital of France is Paris",
-             "Python is a programming language",
-             "Machine learning algorithms learn from data",  # Similar to first
-             "Paris is the capital city of France",  # Similar to second
-         ]
-
-         print("1. Generating embeddings for test texts...")
-         embeddings = embedder.encode(test_texts)
-         print(f"   ✓ Generated {embeddings.shape[0]} embeddings")
-
-         print("\n2. Testing similarity between related texts...")
-
-         # Test similarity between related texts
-         pairs = [
-             (0, 3, "Machine learning texts"),
-             (1, 4, "France/Paris texts"),
-         ]
-
-         for i, j, description in pairs:
-             sim = np.dot(embeddings[i], embeddings[j]) / (
-                 np.linalg.norm(embeddings[i]) * np.linalg.norm(embeddings[j])
-             )
-             print(f"   {description}: {sim:.4f}")
210
- print(f" Text 1: {test_texts[i]}")
211
- print(f" Text 2: {test_texts[j]}")
212
-
213
- print("\n3. Testing embedding properties...")
214
-
215
- # Check if embeddings are normalized
216
- norms = [np.linalg.norm(emb) for emb in embeddings]
217
- print(f" ✓ Embedding norms: {[f'{n:.4f}' for n in norms]}")
218
-
219
- # Check embedding statistics
220
- all_embeddings = embeddings.flatten()
221
- print(f" ✓ All embedding values:")
222
- print(f" - Mean: {np.mean(all_embeddings):.4f}")
223
- print(f" - Std: {np.std(all_embeddings):.4f}")
224
- print(f" - Min: {np.min(all_embeddings):.4f}")
225
- print(f" - Max: {np.max(all_embeddings):.4f}")
226
-
227
- print("\n✅ Embedding quality test completed!")
228
- return True
229
-
230
- except Exception as e:
231
- print(f"\n❌ Embedding quality test failed: {e}")
232
- import traceback
233
- traceback.print_exc()
234
- return False
235
-
236
- def main():
237
- """Run all tests"""
238
- print("SafeRAG Real Embedding Test Suite")
239
- print("=" * 60)
240
-
241
- success = True
242
-
243
- # Test embedding quality
244
- if not test_embedding_quality():
245
- success = False
246
-
247
- # Test real embedding pipeline
248
- if not test_real_embedding_pipeline():
249
- success = False
250
-
251
- print("\n" + "=" * 60)
252
- if success:
253
- print("🎉 All real embedding tests passed!")
254
- print("\nThe system can now:")
255
- print("1. ✅ Load data from knowledge base")
256
- print("2. ✅ Generate real embeddings using sentence-transformers")
257
- print("3. ✅ Build FAISS index with real embeddings")
258
- print("4. ✅ Retrieve relevant passages using real similarity")
259
- print("5. ✅ Calculate cosine similarity between queries and passages")
260
- print("6. ✅ Generate answers based on retrieved passages")
261
- print("7. ✅ Assess embedding quality and properties")
262
- else:
263
- print("❌ Some tests failed. Please check the errors above.")
264
-
265
- return success
266
-
267
- if __name__ == "__main__":
268
- success = main()
269
- sys.exit(0 if success else 1)
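The deleted test above computes query–passage cosine similarity by hand with `np.dot` and `np.linalg.norm`. A self-contained sketch of that exact calculation, on toy vectors so no model download is needed:

```python
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity, matching the formula used in the deleted test."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

q = np.array([1.0, 0.0, 1.0])
p_close = np.array([1.0, 0.1, 1.0])   # nearly parallel to q
p_far = np.array([-1.0, 0.0, -1.0])   # exactly opposite direction

assert abs(cosine_similarity(q, q) - 1.0) < 1e-9        # identical vectors -> 1
assert cosine_similarity(q, p_close) > 0.99             # near-parallel -> close to 1
assert abs(cosine_similarity(q, p_far) + 1.0) < 1e-9    # opposite -> -1
```

Note the test later sorts these scores descending, so the range matters: for raw model embeddings the value lies in [-1, 1], not [0, 1].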
 
requirements.txt DELETED
@@ -1,19 +0,0 @@
- torch>=2.0.0
- transformers>=4.35.0
- datasets>=2.14.0
- vllm>=0.2.0
- faiss-cpu>=1.7.4
- sentence-transformers>=2.2.2
- scikit-learn>=1.3.0
- numpy>=1.24.0
- pandas>=2.0.0
- tqdm>=4.65.0
- gradio>=4.0.0
- accelerate>=0.24.0
- evaluate>=0.4.0
- rouge-score>=0.1.2
- nltk>=3.8.0
- spacy>=3.7.0
- matplotlib>=3.7.0
- seaborn>=0.12.0
- wandb>=0.15.0
 
retriever/__init__.py DELETED
@@ -1,6 +0,0 @@
- from .embedder import Embedder
- from .faiss_index import FAISSIndex
- from .retriever import Retriever
- from .reranker import Reranker
-
- __all__ = ['Embedder', 'FAISSIndex', 'Retriever', 'Reranker']
 
retriever/embedder.py DELETED
@@ -1,49 +0,0 @@
- from sentence_transformers import SentenceTransformer
- from typing import List, Union
- import numpy as np
- import logging
-
- logger = logging.getLogger(__name__)
-
- class Embedder:
-     def __init__(self, model_name: str = "BAAI/bge-large-en-v1.5", device: str = "cuda"):
-         self.model_name = model_name
-         self.device = device
-         self.model = SentenceTransformer(model_name, device=device)
-         logger.info(f"Loaded embedding model: {model_name}")
-
-     def encode(self, texts: Union[str, List[str]], batch_size: int = 16) -> np.ndarray:
-         """Encode texts to embeddings"""
-         if isinstance(texts, str):
-             texts = [texts]
-
-         embeddings = self.model.encode(
-             texts,
-             batch_size=batch_size,
-             convert_to_numpy=True,
-             show_progress_bar=len(texts) > 100
-         )
-
-         return embeddings
-
-     def encode_queries(self, queries: List[str], batch_size: int = 16) -> np.ndarray:
-         """Encode queries with the BGE query instruction prefix"""
-         if not queries:
-             return np.array([])
-
-         # BGE models expect this instruction on queries only
-         prefixed_queries = [f"Represent this sentence for searching relevant passages: {q}" for q in queries]
-         return self.encode(prefixed_queries, batch_size)
-
-     def encode_passages(self, passages: List[str], batch_size: int = 16) -> np.ndarray:
-         """Encode passages (BGE models take passages without an instruction prefix)"""
-         if not passages:
-             return np.array([])
-
-         return self.encode(passages, batch_size)
-
-     def get_dimension(self) -> int:
-         """Get embedding dimension"""
-         return self.model.get_sentence_embedding_dimension()
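Not part of the repository, but useful for reasoning about the `Embedder`'s output shape: a toy numpy stand-in for what a sentence encoder does conceptually (mean-pool token vectors, then L2-normalize). The 384-dimension choice here is only an assumption, matching MiniLM-sized models.

```python
import numpy as np

def mean_pool_and_normalize(token_vectors: np.ndarray) -> np.ndarray:
    """Toy sentence encoder: average the token vectors, then L2-normalize."""
    pooled = token_vectors.mean(axis=0)
    return pooled / np.linalg.norm(pooled)

rng = np.random.default_rng(0)
tokens = rng.normal(size=(7, 384))        # 7 "tokens", 384-dim (MiniLM-sized)
sentence_vec = mean_pool_and_normalize(tokens)

assert sentence_vec.shape == (384,)                       # one vector per sentence
assert abs(np.linalg.norm(sentence_vec) - 1.0) < 1e-9     # unit length
```

Unit-length output is what lets the index below treat inner product as cosine similarity.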
 
retriever/faiss_index.py DELETED
@@ -1,131 +0,0 @@
- import faiss
- import numpy as np
- import pickle
- import os
- from typing import List, Dict, Any, Tuple
- import logging
-
- logger = logging.getLogger(__name__)
-
- def build_faiss_index(embeddings, texts, metadata=None, index_type="HNSW"):
-     """Factory function, called by the pipeline"""
-     if embeddings is None or not hasattr(embeddings, 'shape') or len(embeddings.shape) != 2 or embeddings.shape[0] == 0:
-         raise ValueError(f"Embeddings is empty or not a 2D array. Got shape: {getattr(embeddings, 'shape', None)}")
-     dimension = embeddings.shape[1]
-     index = FAISSIndex(dimension, index_type=index_type)
-     index.build_index(embeddings, texts, metadata)
-     return index
-
- class FAISSIndex:
-     def __init__(self, dimension: int, index_type: str = "HNSW"):
-         self.dimension = dimension
-         self.index_type = index_type
-         self.index = None
-         self.id_to_text = {}
-         self.id_to_metadata = {}
-         self.next_id = 0
-
-     def build_index(self, embeddings: np.ndarray, texts: List[str],
-                     metadata: List[Dict[str, Any]] = None) -> None:
-         """Build FAISS index from embeddings"""
-         if embeddings.shape[1] != self.dimension:
-             raise ValueError(f"Embedding dimension {embeddings.shape[1]} != {self.dimension}")
-         # Normalize embeddings so inner product equals cosine similarity
-         faiss.normalize_L2(embeddings)
-         if self.index_type == "HNSW":
-             # HNSW index for fast approximate search (inner-product metric to match normalization)
-             self.index = faiss.IndexHNSWFlat(self.dimension, 32, faiss.METRIC_INNER_PRODUCT)  # M=32
-             self.index.hnsw.efConstruction = 200
-             self.index.add(embeddings)
-         elif self.index_type == "IVF":
-             nlist = max(1, min(4096, len(embeddings) // 100))  # at least one list for small corpora
-             quantizer = faiss.IndexFlatIP(self.dimension)
-             self.index = faiss.IndexIVFFlat(quantizer, self.dimension, nlist, faiss.METRIC_INNER_PRODUCT)
-             self.index.train(embeddings)
-             self.index.add(embeddings)
-         else:
-             self.index = faiss.IndexFlatIP(self.dimension)
-             self.index.add(embeddings)
-         # Store text and metadata
-         for i, text in enumerate(texts):
-             self.id_to_text[i] = text
-             if metadata and i < len(metadata):
-                 self.id_to_metadata[i] = metadata[i]
-         logger.info(f"Built FAISS {self.index_type} index with {len(embeddings)} vectors")
-
-     def search(self, query_embeddings: np.ndarray, k: int = 10) -> Tuple[np.ndarray, np.ndarray]:
-         """Search for similar vectors"""
-         if self.index is None:
-             raise ValueError("Index not built yet")
-
-         # Normalize query embeddings
-         faiss.normalize_L2(query_embeddings)
-
-         # Search
-         scores, indices = self.index.search(query_embeddings, k)
-
-         return scores, indices
-
-     def get_texts(self, indices: np.ndarray) -> List[str]:
-         """Get texts by indices"""
-         texts = []
-         for idx in indices.flatten():
-             if idx in self.id_to_text:
-                 texts.append(self.id_to_text[idx])
-             else:
-                 texts.append("")
-         return texts
-
-     def get_metadata(self, indices: np.ndarray) -> List[Dict[str, Any]]:
-         """Get metadata by indices"""
-         metadata = []
-         for idx in indices.flatten():
-             if idx in self.id_to_metadata:
-                 metadata.append(self.id_to_metadata[idx])
-             else:
-                 metadata.append({})
-         return metadata
-
-     def save(self, path: str) -> None:
-         """Save index to disk"""
-         os.makedirs(os.path.dirname(path), exist_ok=True)
-
-         # Save FAISS index
-         faiss.write_index(self.index, f"{path}.faiss")
-
-         # Save metadata
-         with open(f"{path}.pkl", "wb") as f:
-             pickle.dump({
-                 'id_to_text': self.id_to_text,
-                 'id_to_metadata': self.id_to_metadata,
-                 'dimension': self.dimension,
-                 'index_type': self.index_type
-             }, f)
-
-         logger.info(f"Saved index to {path}")
-
-     def load(self, path: str) -> None:
-         """Load index from disk"""
-         # Load FAISS index
-         self.index = faiss.read_index(f"{path}.faiss")
-
-         # Load metadata
-         with open(f"{path}.pkl", "rb") as f:
-             data = pickle.load(f)
-             self.id_to_text = data['id_to_text']
-             self.id_to_metadata = data['id_to_metadata']
-             self.dimension = data['dimension']
-             self.index_type = data['index_type']
-
-         logger.info(f"Loaded index from {path}")
-
-     def get_stats(self) -> Dict[str, Any]:
-         """Get index statistics"""
-         if self.index is None:
-             return {}
-
-         return {
-             'num_vectors': self.index.ntotal,
-             'dimension': self.dimension,
-             'index_type': self.index_type,
-             'is_trained': self.index.is_trained if hasattr(self.index, 'is_trained') else True
-         }
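The fallback `IndexFlatIP` branch above is exhaustive inner-product search; since all vectors are L2-normalized first, the scores are cosine similarities. A minimal numpy sketch of that same computation (not using faiss itself, just the math it performs):

```python
import numpy as np

def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Normalize each row to unit length, as faiss.normalize_L2 does in place."""
    return x / np.linalg.norm(x, axis=1, keepdims=True)

def flat_ip_search(corpus: np.ndarray, queries: np.ndarray, k: int):
    """Brute-force inner-product top-k, the computation IndexFlatIP performs."""
    scores = queries @ corpus.T                      # (nq, n) similarity matrix
    indices = np.argsort(-scores, axis=1)[:, :k]     # top-k ids per query, best first
    top_scores = np.take_along_axis(scores, indices, axis=1)
    return top_scores, indices

corpus = l2_normalize(np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]))
query = l2_normalize(np.array([[1.0, 0.9]]))
scores, ids = flat_ip_search(corpus, query, k=2)
assert ids[0, 0] == 2   # the [1, 1] direction is closest to [1, 0.9]
```

Because the rows are unit vectors, `queries @ corpus.T` is exactly the cosine-similarity matrix; higher score means more similar, matching the descending sort in the retriever.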
 
retriever/reranker.py DELETED
@@ -1,46 +0,0 @@
- from sentence_transformers import CrossEncoder
- from typing import List
- import numpy as np
- import logging
-
- logger = logging.getLogger(__name__)
-
- class Reranker:
-     def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2", device: str = "cuda"):
-         self.model_name = model_name
-         self.device = device
-         self.model = CrossEncoder(model_name, device=device)
-         logger.info(f"Loaded reranker model: {model_name}")
-
-     def rerank(self, query: str, passages: List[str], batch_size: int = 32) -> List[float]:
-         """Rerank passages for a query"""
-         if not passages:
-             return []
-
-         # Create query-passage pairs
-         pairs = [(query, passage) for passage in passages]
-
-         # Get relevance scores
-         scores = self.model.predict(pairs, batch_size=batch_size)
-
-         return scores.tolist()
-
-     def rerank_batch(self, queries: List[str], passages_list: List[List[str]],
-                      batch_size: int = 32) -> List[List[float]]:
-         """Rerank passages for multiple queries"""
-         all_scores = []
-
-         for query, passages in zip(queries, passages_list):
-             scores = self.rerank(query, passages, batch_size)
-             all_scores.append(scores)
-
-         return all_scores
-
-     def get_top_k(self, query: str, passages: List[str], k: int = 5) -> List[tuple]:
-         """Get top-k passages with scores"""
-         scores = self.rerank(query, passages)
-
-         # Sort by score
-         ranked = sorted(zip(passages, scores), key=lambda x: x[1], reverse=True)
-
-         return ranked[:k]
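Once the cross-encoder has produced scores, `get_top_k` is a plain sort over (passage, score) pairs. A dependency-free sketch of that ordering with hand-assigned scores (real scores would come from `CrossEncoder.predict`):

```python
def top_k_by_score(passages, scores, k=2):
    """Mirror of Reranker.get_top_k's ordering: highest score first."""
    ranked = sorted(zip(passages, scores), key=lambda x: x[1], reverse=True)
    return ranked[:k]

passages = ["about cats", "about dogs", "about birds"]
scores = [0.1, 0.9, 0.5]          # pretend cross-encoder outputs
top = top_k_by_score(passages, scores, k=2)
assert top == [("about dogs", 0.9), ("about birds", 0.5)]
```

`sorted` is stable, so passages with equal scores keep their original retrieval order.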
 
retriever/retriever.py DELETED
@@ -1,104 +0,0 @@
- from typing import List, Dict, Any, Tuple
- import numpy as np
- from .embedder import Embedder
- from .faiss_index import FAISSIndex
- from .reranker import Reranker
- import logging
-
- logger = logging.getLogger(__name__)
-
- class Retriever:
-     def __init__(self, embedder: Embedder, index: FAISSIndex, reranker: Reranker = None):
-         self.embedder = embedder
-         self.index = index
-         self.reranker = reranker
-
-     def retrieve(self, queries: List[str], k: int = 20,
-                  rerank_k: int = 10) -> List[List[Dict[str, Any]]]:
-         """Retrieve and rerank passages for queries"""
-         if not queries:
-             return []
-
-         # Encode queries
-         query_embeddings = self.embedder.encode_queries(queries)
-
-         # Search index
-         scores, indices = self.index.search(query_embeddings, k)
-
-         # Format results
-         results = []
-         for i, query in enumerate(queries):
-             query_results = []
-             for j, (score, idx) in enumerate(zip(scores[i], indices[i])):
-                 if idx == -1:  # Invalid index
-                     continue
-
-                 text = self.index.id_to_text.get(idx, "")
-                 metadata = self.index.id_to_metadata.get(idx, {})
-
-                 query_results.append({
-                     'text': text,
-                     'score': float(score),
-                     'rank': j + 1,
-                     'metadata': metadata,
-                     'id': idx
-                 })
-
-             results.append(query_results)
-
-         # Rerank if reranker is available
-         if self.reranker and rerank_k < k:
-             reranked_results = []
-             for i, query in enumerate(queries):
-                 passages = [r['text'] for r in results[i][:k]]
-                 rerank_scores = self.reranker.rerank(query, passages)
-
-                 # Reorder results based on rerank scores
-                 reranked = sorted(
-                     zip(results[i][:k], rerank_scores),
-                     key=lambda x: x[1],
-                     reverse=True
-                 )
-
-                 reranked_results.append([
-                     {**result, 'rerank_score': score, 'rank': j + 1}
-                     for j, (result, score) in enumerate(reranked[:rerank_k])
-                 ])
-
-             results = reranked_results
-
-         return results
-
-     def retrieve_single(self, query: str, k: int = 10) -> List[Dict[str, Any]]:
-         """Retrieve for a single query"""
-         results = self.retrieve([query], k)
-         return results[0] if results else []
-
-     def batch_retrieve(self, queries: List[str], batch_size: int = 32,
-                        k: int = 10) -> List[List[Dict[str, Any]]]:
-         """Retrieve for multiple queries in batches"""
-         all_results = []
-
-         for i in range(0, len(queries), batch_size):
-             batch_queries = queries[i:i + batch_size]
-             batch_results = self.retrieve(batch_queries, k)
-             all_results.extend(batch_results)
-
-         return all_results
-
-     def get_retrieval_stats(self, queries: List[str], k: int = 10) -> Dict[str, Any]:
-         """Get retrieval statistics"""
-         results = self.retrieve(queries, k)
-
-         scores = []
-         for query_results in results:
-             scores.extend([r['score'] for r in query_results])
-
-         return {
-             'num_queries': len(queries),
-             'avg_scores': np.mean(scores) if scores else 0,
-             'std_scores': np.std(scores) if scores else 0,
-             'min_scores': np.min(scores) if scores else 0,
-             'max_scores': np.max(scores) if scores else 0,
-             'avg_results_per_query': np.mean([len(r) for r in results])
-         }
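`batch_retrieve` chunks the query list with the standard slice-by-stride pattern. Isolated, the pattern looks like this (a sketch; Python slicing handles the short final batch automatically):

```python
def batches(items, batch_size):
    """The slicing pattern batch_retrieve uses to chunk queries."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

queries = ["q1", "q2", "q3", "q4", "q5"]
assert batches(queries, 2) == [["q1", "q2"], ["q3", "q4"], ["q5"]]
assert batches([], 3) == []   # empty input yields no batches
```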
 
simple_e2e_test.py DELETED
@@ -1,518 +0,0 @@
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- SafeRAG Simple End-to-End Test
- Complete workflow test without external dependencies
- """
-
- import sys
- import os
- import time
- import random
- import math
-
- # Add project root to path
- sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
- def test_basic_functionality():
-     """Test basic Python functionality"""
-     print("Testing basic functionality...")
-
-     try:
-         # Test basic operations
-         assert 1 + 1 == 2, "Basic math failed"
-         assert "hello" + " " + "world" == "hello world", "String concatenation failed"
-         assert len([1, 2, 3]) == 3, "List length failed"
-         print("+ Basic Python operations work")
-
-         # Test random number generation
-         random.seed(42)
-         rand_num = random.random()
-         assert 0 <= rand_num <= 1, "Random number out of range"
-         print("+ Random number generation works")
-
-         return True
-     except Exception as e:
-         print("✗ Basic functionality test failed:", e)
-         return False
-
- def test_text_processing():
-     """Test text processing functionality"""
-     print("\nTesting text processing...")
-
-     try:
-         # Simple text cleaning
-         def clean_text(text):
-             if not text:
-                 return ""
-             # Remove extra whitespace
-             import re
-             text = re.sub(r'\s+', ' ', text)
-             # Remove special characters but keep punctuation
-             text = re.sub(r'[^\w\s\.\,\!\?\;\:\-\(\)]', '', text)
-             return text.strip()
-
-         # Test text cleaning
-         test_text = "   This is a   test text!!!   "
-         cleaned = clean_text(test_text)
-         expected = "This is a test text!!!"
-         assert cleaned == expected, "Text cleaning failed: got '{}', expected '{}'".format(cleaned, expected)
-         print("+ Text cleaning works")
-
-         # Test sentence extraction
-         def extract_sentences(text):
-             sentences = text.split('.')
-             return [clean_text(s) for s in sentences if s.strip()]
-
-         test_text = "First sentence. Second sentence. Third sentence."
-         sentences = extract_sentences(test_text)
-         assert len(sentences) == 3, "Sentence extraction failed: got {} sentences, expected 3".format(len(sentences))
-         print("+ Sentence extraction works")
-
-         return True
-     except Exception as e:
-         print("✗ Text processing test failed:", e)
-         return False
-
- def test_simple_embeddings():
-     """Test simple embedding simulation"""
-     print("\nTesting simple embeddings...")
-
-     try:
-         # Simple embedding simulation using random numbers
-         def create_simple_embeddings(texts, dim=10):
-             """Create simple random embeddings for testing"""
-             random.seed(42)  # For reproducibility
-             embeddings = []
-             for text in texts:
-                 embedding = [random.random() for _ in range(dim)]
-                 # Simple normalization
-                 norm = math.sqrt(sum(x*x for x in embedding))
-                 if norm > 0:
-                     embedding = [x/norm for x in embedding]
-                 embeddings.append(embedding)
-             return embeddings
-
-         # Test embedding creation
-         texts = ["This is a test", "Another test sentence"]
-         embeddings = create_simple_embeddings(texts)
-         assert len(embeddings) == 2, "Wrong number of embeddings"
-         assert len(embeddings[0]) == 10, "Wrong embedding dimension"
-         print("+ Simple embedding creation works")
-
-         # Test similarity calculation
-         def cosine_similarity(a, b):
-             dot_product = sum(x * y for x, y in zip(a, b))
-             norm_a = math.sqrt(sum(x*x for x in a))
-             norm_b = math.sqrt(sum(x*x for x in b))
-             if norm_a == 0 or norm_b == 0:
-                 return 0
-             return dot_product / (norm_a * norm_b)
-
-         sim = cosine_similarity(embeddings[0], embeddings[1])
-         assert 0 <= sim <= 1, "Similarity score out of range: {}".format(sim)
-         print("+ Similarity calculation works")
-
-         return True
-     except Exception as e:
-         print("✗ Simple embeddings test failed:", e)
-         return False
-
- def test_simple_retrieval():
-     """Test simple retrieval functionality"""
-     print("\nTesting simple retrieval...")
-
-     try:
-         # Simple retrieval simulation
-         class SimpleRetriever:
-             def __init__(self, passages, embeddings):
-                 self.passages = passages
-                 self.embeddings = embeddings
-
-             def search(self, query_embedding, k=5):
-                 # Calculate similarities
-                 similarities = []
-                 for embedding in self.embeddings:
-                     sim = sum(x * y for x, y in zip(embedding, query_embedding))
-                     similarities.append(sim)
-
-                 # Get top-k indices
-                 indexed_sims = [(i, sim) for i, sim in enumerate(similarities)]
-                 indexed_sims.sort(key=lambda x: x[1], reverse=True)
-                 top_indices = [i for i, _ in indexed_sims[:k]]
-
-                 # Return results
-                 results = []
-                 for i, idx in enumerate(top_indices):
-                     results.append({
-                         'text': self.passages[idx],
-                         'score': similarities[idx],
-                         'rank': i + 1
-                     })
-                 return results
-
-         # Create test data
-         passages = [
-             "Machine learning is a subset of artificial intelligence.",
-             "Deep learning uses neural networks with multiple layers.",
-             "Natural language processing deals with text and speech.",
-             "Computer vision focuses on image and video analysis."
-         ]
-
-         # Create simple embeddings
-         def create_simple_embeddings(texts, dim=10):
-             random.seed(42)
-             embeddings = []
-             for text in texts:
-                 embedding = [random.random() for _ in range(dim)]
-                 norm = math.sqrt(sum(x*x for x in embedding))
-                 if norm > 0:
-                     embedding = [x/norm for x in embedding]
-                 embeddings.append(embedding)
-             return embeddings
-
-         embeddings = create_simple_embeddings(passages)
-
-         # Test retrieval
-         retriever = SimpleRetriever(passages, embeddings)
-         query_embedding = [random.random() for _ in range(10)]
-         norm = math.sqrt(sum(x*x for x in query_embedding))
-         if norm > 0:
-             query_embedding = [x/norm for x in query_embedding]
-
-         results = retriever.search(query_embedding, k=3)
-         assert len(results) == 3, "Retrieval returned wrong number of results: {}".format(len(results))
-         assert all('text' in r and 'score' in r for r in results), "Retrieval results missing fields"
-         print("+ Simple retrieval works")
-
-         return True
-     except Exception as e:
-         print("✗ Simple retrieval test failed:", e)
-         return False
-
- def test_risk_calibration():
-     """Test risk calibration functionality"""
-     print("\nTesting risk calibration...")
-
-     try:
-         # Simple risk feature extraction
-         def extract_risk_features(question, retrieved_passages):
-             features = {}
-
-             if not retrieved_passages:
-                 return {'num_passages': 0, 'avg_similarity': 0.0, 'diversity': 0.0}
-
-             # Basic features
-             features['num_passages'] = len(retrieved_passages)
-             scores = [p['score'] for p in retrieved_passages]
-             features['avg_similarity'] = sum(scores) / len(scores)
-             features['max_similarity'] = max(scores)
-             features['min_similarity'] = min(scores)
-
-             # Simple diversity calculation
-             if len(scores) > 1:
-                 mean_score = features['avg_similarity']
-                 variance = sum((x - mean_score) ** 2 for x in scores) / len(scores)
-                 features['diversity'] = 1.0 - math.sqrt(variance)
-             else:
-                 features['diversity'] = 1.0
-
-             return features
-
-         # Simple risk prediction
-         def predict_risk(features):
-             # Simple heuristic for risk scoring
-             risk_score = 0.0
-
-             # Few passages = higher risk
-             if features['num_passages'] < 3:
-                 risk_score += 0.3
-
-             # Low similarity = higher risk
-             if features['avg_similarity'] < 0.5:
-                 risk_score += 0.2
-
-             # Low diversity = higher risk
-             if features['diversity'] < 0.3:
-                 risk_score += 0.2
-
-             return min(1.0, risk_score)
-
-         # Test risk feature extraction
-         question = "What is machine learning?"
-         passages = [
-             {'text': 'ML is AI subset', 'score': 0.8},
-             {'text': 'Neural networks are used', 'score': 0.7},
-             {'text': 'Deep learning is popular', 'score': 0.6}
-         ]
-
-         features = extract_risk_features(question, passages)
-         assert 'num_passages' in features, "Missing num_passages feature"
-         assert features['num_passages'] == 3, "Wrong number of passages: {}".format(features['num_passages'])
-         print("+ Risk feature extraction works")
-
-         # Test risk prediction
-         risk_score = predict_risk(features)
-         assert 0 <= risk_score <= 1, "Risk score out of range: {}".format(risk_score)
-         print("+ Risk prediction works")
-
-         return True
-     except Exception as e:
-         print("✗ Risk calibration test failed:", e)
-         return False
-
- def test_generation():
-     """Test generation functionality"""
-     print("\nTesting generation...")
-
-     try:
-         # Simple generation simulation
-         def generate_answer(question, retrieved_passages, risk_score):
-             # Simple template-based generation
-             context = " ".join([p['text'] for p in retrieved_passages[:3]])
-
-             if risk_score < 0.3:
-                 # Low risk: confident answer
-                 answer = "Based on the information: {}. The answer is: {}.".format(
-                     context, "This is a confident answer."
-                 )
-             elif risk_score < 0.7:
-                 # Medium risk: cautious answer
-                 answer = "Based on the available information: {}. The answer might be: {}.".format(
-                     context, "This is a cautious answer."
-                 )
-             else:
-                 # High risk: uncertain answer
-                 answer = "The available information: {} is limited. I'm not certain, but it might be: {}.".format(
-                     context, "This is an uncertain answer."
-                 )
-
-             return answer
-
-         # Test generation
-         question = "What is machine learning?"
-         passages = [
-             {'text': 'Machine learning is AI subset', 'score': 0.8},
-             {'text': 'It uses algorithms', 'score': 0.7}
-         ]
-
-         # Test different risk levels
-         for risk_score in [0.2, 0.5, 0.8]:
-             answer = generate_answer(question, passages, risk_score)
-             assert len(answer) > 0, "Empty answer generated"
-             assert "machine learning" in answer.lower() or "ai" in answer.lower(), "Answer doesn't address question"
-
-         print("+ Generation works")
-
-         return True
-     except Exception as e:
-         print("✗ Generation test failed:", e)
-         return False
-
- def test_evaluation():
-     """Test evaluation functionality"""
-     print("\nTesting evaluation...")
-
-     try:
-         # Simple evaluation metrics
-         def exact_match(prediction, reference):
-             return prediction.lower().strip() == reference.lower().strip()
-
-         def f1_score(prediction, reference):
-             pred_words = set(prediction.lower().split())
-             ref_words = set(reference.lower().split())
-
-             if len(ref_words) == 0:
-                 return 1.0 if len(pred_words) == 0 else 0.0
-
-             common = pred_words & ref_words
-             precision = len(common) / len(pred_words) if pred_words else 0.0
-             recall = len(common) / len(ref_words)
-
-             if precision + recall == 0:
-                 return 0.0
-
-             return 2 * precision * recall / (precision + recall)
-
-         # Test evaluation
-         predictions = ["Machine learning is AI", "Deep learning uses neural networks"]
-         references = ["Machine learning is AI", "Deep learning uses neural networks"]
-
-         # Test exact match
-         em_scores = [exact_match(p, r) for p, r in zip(predictions, references)]
-         assert all(em_scores), "Exact match failed"
-         print("+ Exact match evaluation works")
-
-         # Test F1 score
-         f1_scores = [f1_score(p, r) for p, r in zip(predictions, references)]
-         assert all(0 <= score <= 1 for score in f1_scores), "F1 scores out of range"
-         print("+ F1 score evaluation works")
-
-         return True
-     except Exception as e:
-         print("✗ Evaluation test failed:", e)
-         return False
-
- def test_end_to_end_workflow():
-     """Test complete end-to-end workflow"""
-     print("\nTesting end-to-end workflow...")
-
-     try:
-         # Simulate complete RAG pipeline
-         def rag_pipeline(question):
-             # Step 1: Create simple embeddings
-             passages = [
-                 "Machine learning is a subset of artificial intelligence.",
-                 "Deep learning uses neural networks with multiple layers.",
-                 "Natural language processing deals with text and speech.",
-                 "Computer vision focuses on image and video analysis."
-             ]
-
-             # Simulate embeddings
-             random.seed(42)
-             embeddings = []
-             for passage in passages:
-                 embedding = [random.random() for _ in range(10)]
-                 norm = math.sqrt(sum(x*x for x in embedding))
-                 if norm > 0:
-                     embedding = [x/norm for x in embedding]
-                 embeddings.append(embedding)
-
-             # Step 2: Retrieve relevant passages
-             query_embedding = [random.random() for _ in range(10)]
-             norm = math.sqrt(sum(x*x for x in query_embedding))
-             if norm > 0:
-                 query_embedding = [x/norm for x in query_embedding]
-
-             similarities = []
-             for embedding in embeddings:
-                 sim = sum(x * y for x, y in zip(embedding, query_embedding))
-                 similarities.append(sim)
-
-             indexed_sims = [(i, sim) for i, sim in enumerate(similarities)]
-             indexed_sims.sort(key=lambda x: x[1], reverse=True)
-             top_indices = [i for i, _ in indexed_sims[:3]]
-
-             retrieved_passages = []
-             for i, idx in enumerate(top_indices):
-                 retrieved_passages.append({
-                     'text': passages[idx],
-                     'score': similarities[idx],
-                     'rank': i + 1
-                 })
-
-             # Step 3: Extract risk features
-             scores = [p['score'] for p in retrieved_passages]
-             features = {
-                 'num_passages': len(retrieved_passages),
-                 'avg_similarity': sum(scores) / len(scores) if scores else 0.0,
-                 'diversity': 1.0 - math.sqrt(sum((x - sum(scores)/len(scores))**2 for x in scores) / len(scores)) if len(scores) > 1 else 1.0
-             }
-
-             # Step 4: Predict risk
-             risk_score = 0.0
-             if features['num_passages'] < 3:
-                 risk_score += 0.3
-             if features['avg_similarity'] < 0.5:
-                 risk_score += 0.2
-             if features['diversity'] < 0.3:
-                 risk_score += 0.2
-             risk_score = min(1.0, risk_score)
-
-             # Step 5: Generate answer
-             context = " ".join([p['text'] for p in retrieved_passages[:3]])
-             if risk_score < 0.3:
-                 answer = "Based on the information: {}. The answer is: Machine learning is a subset of AI.".format(context)
-             elif risk_score < 0.7:
-                 answer = "Based on the available information: {}. The answer might be: Machine learning is likely a subset of AI.".format(context)
-             else:
-                 answer = "The available information: {} is limited. I'm not certain, but it might be: Machine learning could be related to AI.".format(context)
-
-             return {
-                 'question': question,
-                 'answer': answer,
-                 'retrieved_passages': retrieved_passages,
-                 'risk_score': risk_score,
-                 'features': features
-             }
-
-         # Test complete pipeline
-         question = "What is machine learning?"
-         result = rag_pipeline(question)
-
-         # Validate result
-         assert 'question' in result, "Missing question in result"
-         assert 'answer' in result, "Missing answer in result"
-         assert 'retrieved_passages' in result, "Missing retrieved passages"
-         assert 'risk_score' in result, "Missing risk score"
-         assert 'features' in result, "Missing features"
-
-         assert result['question'] == question, "Question not preserved"
-         assert len(result['answer']) > 0, "Empty answer"
-         assert len(result['retrieved_passages']) > 0, "No retrieved passages"
-         assert 0 <= result['risk_score'] <= 1, "Risk score out of range: {}".format(result['risk_score'])
-
-         print("+ End-to-end workflow works")
-         print("  Question: {}".format(result['question']))
-         print("  Answer: {}".format(result['answer'][:100] + "..."))
-         print("  Risk Score: {:.3f}".format(result['risk_score']))
-         print("  Retrieved Passages: {}".format(len(result['retrieved_passages'])))
-
-         return True
-     except Exception as e:
-         print("✗ End-to-end workflow test failed:", e)
-         return False
-
- def main():
-     """Run all end-to-end tests"""
-     print("SafeRAG Simple End-to-End Test Suite")
-     print("=" * 50)
-
-     start_time = time.time()
-
-     tests = [
-         test_basic_functionality,
-         test_text_processing,
-         test_simple_embeddings,
477
- test_simple_retrieval,
478
- test_risk_calibration,
479
- test_generation,
480
- test_evaluation,
481
- test_end_to_end_workflow
482
- ]
483
-
484
- passed = 0
485
- total = len(tests)
486
-
487
- for test in tests:
488
- try:
489
- if test():
490
- passed += 1
491
- except Exception as e:
492
- print("✗ Test {} failed with exception: {}".format(test.__name__, e))
493
-
494
- end_time = time.time()
495
-
496
- print("\n" + "=" * 50)
497
- print("Test Results:")
498
- print("Passed: {}/{}".format(passed, total))
499
- print("Time: {:.2f} seconds".format(end_time - start_time))
500
-
501
- if passed == total:
502
- print("✓ All tests passed! SafeRAG end-to-end workflow is working.")
503
- print("\nThe system can:")
504
- print("- Process text and extract sentences")
505
- print("- Create simple embeddings and calculate similarities")
506
- print("- Retrieve relevant passages based on similarity")
507
- print("- Extract risk features and predict risk scores")
508
- print("- Generate answers with different risk-aware strategies")
509
- print("- Evaluate answers using standard metrics")
510
- print("- Run complete end-to-end RAG pipeline")
511
- return True
512
- else:
513
- print("✗ Some tests failed. Please check the errors above.")
514
- return False
515
-
516
- if __name__ == "__main__":
517
- success = main()
518
- sys.exit(0 if success else 1)
simple_test.py DELETED
@@ -1,167 +0,0 @@
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- Simple SafeRAG Test
- Basic functionality test without complex dependencies
- """
-
- import sys
- import os
- sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
- def test_imports():
-     """Test that all modules can be imported"""
-     print("Testing imports...")
-
-     try:
-         from data_processing import DataLoader, Preprocessor
-         print("+ DataLoader and Preprocessor imported successfully")
-     except Exception as e:
-         print("✗ Failed to import DataLoader/Preprocessor:", e)
-         return False
-
-     try:
-         from retriever import Embedder, FAISSIndex, Retriever, Reranker
-         print("+ Retriever modules imported successfully")
-     except Exception as e:
-         print("✗ Failed to import retriever modules:", e)
-         return False
-
-     try:
-         from generator import VLLMServer, SafeGenerator, PromptTemplates
-         print("+ Generator modules imported successfully")
-     except Exception as e:
-         print("✗ Failed to import generator modules:", e)
-         return False
-
-     try:
-         from calibration import RiskFeatureExtractor, CalibrationHead
-         print("+ Calibration modules imported successfully")
-     except Exception as e:
-         print("✗ Failed to import calibration modules:", e)
-         return False
-
-     try:
-         from eval import QAEvaluator, AttributionEvaluator, CalibrationEvaluator
-         print("+ Evaluation modules imported successfully")
-     except Exception as e:
-         print("✗ Failed to import evaluation modules:", e)
-         return False
-
-     return True
-
- def test_basic_functionality():
-     """Test basic functionality without heavy dependencies"""
-     print("\nTesting basic functionality...")
-
-     try:
-         # Test Preprocessor
-         from data_processing.preprocessor import Preprocessor
-         preprocessor = Preprocessor()
-
-         # Test text cleaning
-         text = "  This is a test text.  "
-         cleaned = preprocessor.clean_text(text)
-         assert cleaned == "This is a test text.", "Expected 'This is a test text.', got '{}'".format(cleaned)
-         print("+ Text cleaning works")
-
-         # Test sentence extraction
-         text = "First sentence. Second sentence. Third sentence."
-         sentences = preprocessor.extract_sentences(text)
-         assert len(sentences) == 3, "Expected 3 sentences, got {}".format(len(sentences))
-         print("+ Sentence extraction works")
-
-     except Exception as e:
-         print("✗ Preprocessor test failed:", e)
-         return False
-
-     try:
-         # Test PromptTemplates
-         from generator.prompt_templates import PromptTemplates
-         templates = PromptTemplates()
-
-         # Test prompt formatting
-         prompt = templates.format_prompt(
-             'rag',
-             question="What is AI?",
-             context="AI is artificial intelligence."
-         )
-         assert "What is AI?" in prompt, "Question not found in prompt"
-         assert "AI is artificial intelligence." in prompt, "Context not found in prompt"
-         print("+ Prompt templates work")
-
-     except Exception as e:
-         print("✗ PromptTemplates test failed:", e)
-         return False
-
-     try:
-         # Test QAEvaluator
-         from eval.eval_qa import QAEvaluator
-         evaluator = QAEvaluator()
-
-         # Test exact match
-         predictions = ["Paris", "Paris"]
-         references = ["Paris", "London"]
-         em = evaluator.exact_match(predictions, references)
-         assert em == 0.5, "Expected 0.5, got {}".format(em)
-         print("+ QA evaluation works")
-
-     except Exception as e:
-         print("✗ QAEvaluator test failed:", e)
-         return False
-
-     return True
-
- def test_config():
-     """Test configuration loading"""
-     print("\nTesting configuration...")
-
-     try:
-         import yaml
-         with open('config.yaml', 'r') as f:
-             config = yaml.safe_load(f)
-
-         # Check required sections
-         required_sections = ['models', 'data', 'index', 'retrieval', 'calibration', 'evaluation']
-         for section in required_sections:
-             assert section in config, "Missing config section: {}".format(section)
-
-         print("+ Configuration file is valid")
-         return True
-
-     except Exception as e:
-         print("✗ Configuration test failed:", e)
-         return False
-
- def main():
-     """Run all tests"""
-     print("SafeRAG Simple Test Suite")
-     print("=" * 40)
-
-     all_passed = True
-
-     # Test imports
-     if not test_imports():
-         all_passed = False
-
-     # Test basic functionality
-     if not test_basic_functionality():
-         all_passed = False
-
-     # Test configuration
-     if not test_config():
-         all_passed = False
-
-     print("\n" + "=" * 40)
-     if all_passed:
-         print("+ All tests passed!")
-         print("SafeRAG is ready to use.")
-     else:
-         print("✗ Some tests failed.")
-         print("Please check the errors above.")
-
-     return all_passed
-
- if __name__ == "__main__":
-     success = main()
-     sys.exit(0 if success else 1)