Tairun Meng committed on
Commit
8a3396b
·
1 Parent(s): db06013
Files changed (2) hide show
  1. real_embedding_test.py +269 -0
  2. retriever/embedder.py +1 -1
real_embedding_test.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ SafeRAG Real Embedding Test
5
+ Load data -> Generate real embeddings using sentence-transformers -> Build index -> Retrieve
6
+ """
7
+
8
+ import sys
9
+ import os
10
+ import time
11
+ import numpy as np
12
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
13
+
14
+ def test_real_embedding_pipeline():
15
+ """Test the complete pipeline with real embeddings"""
16
+ print("SafeRAG Real Embedding Pipeline Test")
17
+ print("=" * 50)
18
+
19
+ try:
20
+ # Step 1: Load data
21
+ print("\n1. Loading data...")
22
+ from data_processing import DataLoader, Preprocessor
23
+
24
+ loader = DataLoader()
25
+ preprocessor = Preprocessor()
26
+
27
+ # Load knowledge base
28
+ kb_passages = loader.get_knowledge_base()
29
+ print(f" ✓ Loaded {len(kb_passages)} knowledge base passages")
30
+
31
+ # Show sample passages
32
+ for i, passage in enumerate(kb_passages):
33
+ print(f" [{i+1}] {passage}")
34
+
35
+ # Preprocess passages
36
+ processed_passages = preprocessor.preprocess_passages(kb_passages)
37
+ print(f" ✓ Preprocessed {len(processed_passages)} passages")
38
+
39
+ # Step 2: Generate real embeddings
40
+ print("\n2. Generating real embeddings with sentence-transformers...")
41
+ from retriever import Embedder
42
+
43
+ # Use a smaller model for faster testing
44
+ embedder = Embedder(model_name="all-MiniLM-L6-v2", device="cpu")
45
+ print(f" ✓ Loaded embedding model: {embedder.model_name}")
45
+ print(f" ✓ Embedding dimension: {embedder.get_dimension()}")
47
+
48
+ # Extract text from processed passages
49
+ passage_texts = [p['text'] for p in processed_passages]
50
+
51
+ # Generate embeddings
52
+ start_time = time.time()
53
+ embeddings = embedder.encode_passages(passage_texts)
54
+ embedding_time = time.time() - start_time
55
+
56
+ print(f" ✓ Generated {embeddings.shape[0]} embeddings in {embedding_time:.3f}s")
57
+ print(f" ✓ Embedding shape: {embeddings.shape}")
58
+ print(f" ✓ Embedding type: {type(embeddings)}")
59
+
60
+ # Show embedding statistics
61
+ print(f" ✓ Embedding stats:")
62
+ print(f" - Mean: {np.mean(embeddings):.4f}")
63
+ print(f" - Std: {np.std(embeddings):.4f}")
64
+ print(f" - Min: {np.min(embeddings):.4f}")
65
+ print(f" - Max: {np.max(embeddings):.4f}")
66
+
67
+ # Step 3: Build FAISS index
68
+ print("\n3. Building FAISS index...")
69
+ from retriever import FAISSIndex
70
+
71
+ index = FAISSIndex(embedder.get_dimension())
72
+ start_time = time.time()
73
+ index.build_index(embeddings, passage_texts)
74
+ build_time = time.time() - start_time
75
+
76
+ print(f" ✓ Built FAISS index in {build_time:.3f}s")
77
+ print(f" ✓ Index contains {index.index.ntotal} vectors")
78
+
79
+ # Step 4: Test retrieval
80
+ print("\n4. Testing retrieval...")
81
+ from retriever import Retriever
82
+
83
+ retriever = Retriever(embedder, index, None) # No reranker for simplicity
84
+
85
+ test_queries = [
86
+ "What is machine learning?",
87
+ "Tell me about the capital of France",
88
+ "How does Python work?",
89
+ "What is artificial intelligence?"
90
+ ]
91
+
92
+ for query in test_queries:
93
+ print(f"\n Query: '{query}'")
94
+ start_time = time.time()
95
+ results = retriever.retrieve_single(query, k=3)
96
+ retrieval_time = time.time() - start_time
97
+
98
+ print(f" ✓ Retrieved {len(results)} passages in {retrieval_time:.3f}s")
99
+ for i, result in enumerate(results):
100
+ print(f" [{i+1}] Score: {result['score']:.4f}")
101
+ print(f" Text: {result['text'][:100]}...")
102
+
103
+ # Step 5: Test similarity calculation
104
+ print("\n5. Testing similarity calculation...")
105
+
106
+ # Test query-passage similarity
107
+ query = "What is machine learning?"
108
+ query_embedding = embedder.encode_queries([query])[0]
109
+
110
+ print(f" Query: '{query}'")
111
+ print(f" Query embedding shape: {query_embedding.shape}")
112
+
113
+ # Calculate similarities with all passages
114
+ similarities = []
115
+ for i, passage_embedding in enumerate(embeddings):
116
+ # Cosine similarity
117
+ similarity = np.dot(query_embedding, passage_embedding) / (
118
+ np.linalg.norm(query_embedding) * np.linalg.norm(passage_embedding)
119
+ )
120
+ similarities.append((i, similarity, passage_texts[i]))
121
+
122
+ # Sort by similarity
123
+ similarities.sort(key=lambda x: x[1], reverse=True)
124
+
125
+ print(f" ✓ Calculated similarities with {len(similarities)} passages")
126
+ print(f" Top 3 most similar passages:")
127
+ for i, (idx, sim, text) in enumerate(similarities[:3]):
128
+ print(f" [{i+1}] Similarity: {sim:.4f}")
129
+ print(f" Text: {text[:80]}...")
130
+
131
+ # Step 6: Test generation
132
+ print("\n6. Testing generation...")
133
+ from generator import SafeGenerator, PromptTemplates
134
+
135
+ templates = PromptTemplates()
136
+ generator = SafeGenerator(None, None, 0.3, 0.7) # Simplified version
137
+
138
+ test_query = "What is machine learning?"
139
+ retrieved_passages = retriever.retrieve_single(test_query, k=3)
140
+
141
+ print(f" Query: '{test_query}'")
142
+ print(f" Retrieved {len(retrieved_passages)} passages")
143
+
144
+ # Generate answer
145
+ start_time = time.time()
146
+ result = generator.generate_with_strategy(test_query, retrieved_passages)
147
+ generation_time = time.time() - start_time
148
+
149
+ print(f" ✓ Generated answer in {generation_time:.3f}s")
150
+ print(f" Answer: {result['answer'][:200]}...")
151
+ print(f" Risk Score: {result['risk_score']:.3f}")
152
+ print(f" Strategy: {result['strategy']}")
153
+
154
+ print("\n" + "=" * 50)
155
+ print("🎉 Real embedding pipeline test completed successfully!")
156
+ print("\nPipeline Summary:")
157
+ print(f"- Data Loading: {len(kb_passages)} passages")
158
+ print(f"- Real Embedding Generation: {embeddings.shape[0]} vectors ({embeddings.shape[1]}D)")
159
+ print(f"- Index Building: {index.index.ntotal} indexed vectors")
160
+ print(f"- Retrieval: {len(test_queries)} test queries")
161
+ print(f"- Similarity Calculation: Cosine similarity with all passages")
162
+ print(f"- Generation: Risk-aware answer generation")
163
+
164
+ return True
165
+
166
+ except Exception as e:
167
+ print(f"\n❌ Pipeline test failed: {e}")
168
+ import traceback
169
+ traceback.print_exc()
170
+ return False
171
+
172
+ def test_embedding_quality():
173
+ """Test embedding quality and properties"""
174
+ print("\n" + "=" * 50)
175
+ print("Testing Embedding Quality")
176
+ print("=" * 50)
177
+
178
+ try:
179
+ from retriever import Embedder
180
+
181
+ # Initialize embedder
182
+ embedder = Embedder(model_name="all-MiniLM-L6-v2", device="cpu")
183
+
184
+ # Test texts
185
+ test_texts = [
186
+ "Machine learning is a subset of artificial intelligence",
187
+ "The capital of France is Paris",
188
+ "Python is a programming language",
189
+ "Machine learning algorithms learn from data", # Similar to first
190
+ "Paris is the capital city of France", # Similar to second
191
+ ]
192
+
193
+ print("1. Generating embeddings for test texts...")
194
+ embeddings = embedder.encode(test_texts)
195
+ print(f" ✓ Generated {embeddings.shape[0]} embeddings")
196
+
197
+ print("\n2. Testing similarity between related texts...")
198
+
199
+ # Test similarity between related texts
200
+ pairs = [
201
+ (0, 3, "Machine learning texts"),
202
+ (1, 4, "France/Paris texts"),
203
+ ]
204
+
205
+ for i, j, description in pairs:
206
+ sim = np.dot(embeddings[i], embeddings[j]) / (
207
+ np.linalg.norm(embeddings[i]) * np.linalg.norm(embeddings[j])
208
+ )
209
+ print(f" {description}: {sim:.4f}")
210
+ print(f" Text 1: {test_texts[i]}")
211
+ print(f" Text 2: {test_texts[j]}")
212
+
213
+ print("\n3. Testing embedding properties...")
214
+
215
+ # Check if embeddings are normalized
216
+ norms = [np.linalg.norm(emb) for emb in embeddings]
217
+ print(f" ✓ Embedding norms: {[f'{n:.4f}' for n in norms]}")
218
+
219
+ # Check embedding statistics
220
+ all_embeddings = embeddings.flatten()
221
+ print(f" ✓ All embedding values:")
222
+ print(f" - Mean: {np.mean(all_embeddings):.4f}")
223
+ print(f" - Std: {np.std(all_embeddings):.4f}")
224
+ print(f" - Min: {np.min(all_embeddings):.4f}")
225
+ print(f" - Max: {np.max(all_embeddings):.4f}")
226
+
227
+ print("\n✅ Embedding quality test completed!")
228
+ return True
229
+
230
+ except Exception as e:
231
+ print(f"\n❌ Embedding quality test failed: {e}")
232
+ import traceback
233
+ traceback.print_exc()
234
+ return False
235
+
236
+ def main():
237
+ """Run all tests"""
238
+ print("SafeRAG Real Embedding Test Suite")
239
+ print("=" * 60)
240
+
241
+ success = True
242
+
243
+ # Test embedding quality
244
+ if not test_embedding_quality():
245
+ success = False
246
+
247
+ # Test real embedding pipeline
248
+ if not test_real_embedding_pipeline():
249
+ success = False
250
+
251
+ print("\n" + "=" * 60)
252
+ if success:
253
+ print("🎉 All real embedding tests passed!")
254
+ print("\nThe system can now:")
255
+ print("1. ✅ Load data from knowledge base")
256
+ print("2. ✅ Generate real embeddings using sentence-transformers")
257
+ print("3. ✅ Build FAISS index with real embeddings")
258
+ print("4. ✅ Retrieve relevant passages using real similarity")
259
+ print("5. ✅ Calculate cosine similarity between queries and passages")
260
+ print("6. ✅ Generate answers based on retrieved passages")
261
+ print("7. ✅ Assess embedding quality and properties")
262
+ else:
263
+ print("❌ Some tests failed. Please check the errors above.")
264
+
265
+ return success
266
+
267
+ if __name__ == "__main__":
268
+ success = main()
269
+ sys.exit(0 if success else 1)
retriever/embedder.py CHANGED
@@ -46,4 +46,4 @@ class Embedder:
46
 
47
  def get_dimension(self) -> int:
48
  """Get embedding dimension"""
49
- return self.model.get_sentence_embedding_dimension()
 
46
 
47
  def get_dimension(self) -> int:
48
  """Get embedding dimension"""
49
+ return self.model.get_sentence_embedding_dimension()