ChiragPatankar commited on
Commit
2ed3c40
·
verified ·
1 Parent(s): 16001d7

Upload scripts/validate_rag.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/validate_rag.py +466 -0
scripts/validate_rag.py ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Automated RAG pipeline validation script.
3
+ Tests end-to-end functionality, multi-tenant isolation, and anti-hallucination.
4
+ """
5
+ import httpx
6
+ import time
7
+ import json
8
+ from pathlib import Path
9
+ from typing import Dict, List, Any, Tuple
10
+ import sys
11
+
12
+ # Add parent directory to path
13
+ sys.path.insert(0, str(Path(__file__).parent.parent))
14
+
15
+ BASE_URL = "http://localhost:8000"
16
+ TEST_TENANT_A = "tenant_A"
17
+ TEST_TENANT_B = "tenant_B"
18
+ TEST_USER_A = "user_A"
19
+ TEST_USER_B = "user_B"
20
+ TEST_KB_A = "kb_A"
21
+ TEST_KB_B = "kb_B"
22
+
23
+ # Test documents
24
+ TENANT_A_DOC = Path(__file__).parent.parent / "data" / "test_docs" / "tenant_A_kb.md"
25
+ TENANT_B_DOC = Path(__file__).parent.parent / "data" / "test_docs" / "tenant_B_kb.md"
26
+
27
+ # Test results storage
28
+ test_results: List[Dict[str, Any]] = []
29
+
30
+
31
+ def print_header(text: str):
32
+ """Print a formatted header."""
33
+ print("\n" + "=" * 80)
34
+ print(f" {text}")
35
+ print("=" * 80)
36
+
37
+
38
+ def print_test(test_name: str, passed: bool, reason: str = ""):
39
+ """Print test result."""
40
+ status = "[PASS]" if passed else "[FAIL]"
41
+ print(f"{status} | {test_name}")
42
+ if reason:
43
+ print(f" └─ {reason}")
44
+ test_results.append({
45
+ "test": test_name,
46
+ "passed": passed,
47
+ "reason": reason
48
+ })
49
+
50
+
51
+ def wait_for_server(max_retries: int = 10, delay: int = 2) -> bool:
52
+ """Wait for the server to be ready."""
53
+ print("Waiting for server to be ready...")
54
+ for i in range(max_retries):
55
+ try:
56
+ response = httpx.get(f"{BASE_URL}/health", timeout=5)
57
+ if response.status_code == 200:
58
+ print("[OK] Server is ready")
59
+ return True
60
+ except Exception:
61
+ pass
62
+ time.sleep(delay)
63
+ print(f" Retry {i+1}/{max_retries}...")
64
+ print("[FAIL] Server not ready after max retries")
65
+ return False
66
+
67
+
68
+ def upload_document(
69
+ client: httpx.Client,
70
+ file_path: Path,
71
+ tenant_id: str,
72
+ user_id: str,
73
+ kb_id: str
74
+ ) -> Dict[str, Any]:
75
+ """Upload a document to the knowledge base."""
76
+ try:
77
+ with open(file_path, "rb") as f:
78
+ files = {"file": (file_path.name, f, "text/markdown")}
79
+ data = {
80
+ "tenant_id": tenant_id,
81
+ "user_id": user_id,
82
+ "kb_id": kb_id
83
+ }
84
+ response = client.post(
85
+ f"{BASE_URL}/kb/upload",
86
+ files=files,
87
+ data=data,
88
+ timeout=60
89
+ )
90
+ if response.status_code == 200:
91
+ return {"success": True, "data": response.json()}
92
+ else:
93
+ return {"success": False, "error": response.text}
94
+ except Exception as e:
95
+ return {"success": False, "error": str(e)}
96
+
97
+
98
+ def test_retrieval(
99
+ client: httpx.Client,
100
+ query: str,
101
+ tenant_id: str,
102
+ user_id: str,
103
+ kb_id: str,
104
+ expected_keywords: List[str],
105
+ should_not_contain: List[str] = None,
106
+ top_k: int = 5
107
+ ) -> Tuple[bool, str]:
108
+ """Test retrieval accuracy."""
109
+ try:
110
+ # Use GET for search endpoint with headers for dev mode auth
111
+ headers = {
112
+ "X-Tenant-Id": tenant_id,
113
+ "X-User-Id": user_id
114
+ }
115
+ response = client.get(
116
+ f"{BASE_URL}/kb/search",
117
+ params={
118
+ "query": query,
119
+ "kb_id": kb_id,
120
+ "top_k": top_k
121
+ },
122
+ headers=headers,
123
+ timeout=30
124
+ )
125
+
126
+ if response.status_code != 200:
127
+ return False, f"API returned {response.status_code}: {response.text}"
128
+
129
+ data = response.json()
130
+ results = data.get("results", [])
131
+
132
+ if not results:
133
+ return False, "No results retrieved"
134
+
135
+ # Check tenant isolation
136
+ for result in results:
137
+ metadata = result.get("metadata", {})
138
+ result_tenant = metadata.get("tenant_id")
139
+ if result_tenant != tenant_id:
140
+ return False, f"Tenant leak detected! Got tenant_id={result_tenant}, expected {tenant_id}"
141
+
142
+ # Check for expected keywords
143
+ all_content = " ".join([r.get("content", "") for r in results]).lower()
144
+ found_keywords = [kw for kw in expected_keywords if kw.lower() in all_content]
145
+
146
+ if not found_keywords:
147
+ return False, f"Expected keywords not found: {expected_keywords}"
148
+
149
+ # Check for forbidden content
150
+ if should_not_contain:
151
+ for forbidden in should_not_contain:
152
+ if forbidden.lower() in all_content:
153
+ return False, f"Forbidden content found: {forbidden}"
154
+
155
+ return True, f"Retrieved {len(results)} results, found keywords: {found_keywords}"
156
+
157
+ except Exception as e:
158
+ return False, f"Error: {str(e)}"
159
+
160
+
161
+ def test_chat(
162
+ client: httpx.Client,
163
+ question: str,
164
+ tenant_id: str,
165
+ user_id: str,
166
+ kb_id: str,
167
+ expected_keywords: List[str] = None,
168
+ should_refuse: bool = False,
169
+ should_not_contain: List[str] = None
170
+ ) -> Tuple[bool, str, Dict[str, Any]]:
171
+ """Test full chat endpoint."""
172
+ try:
173
+ # Include headers for dev mode auth
174
+ headers = {
175
+ "X-Tenant-Id": tenant_id,
176
+ "X-User-Id": user_id
177
+ }
178
+ response = client.post(
179
+ f"{BASE_URL}/chat",
180
+ json={
181
+ "tenant_id": tenant_id,
182
+ "user_id": user_id,
183
+ "kb_id": kb_id,
184
+ "question": question
185
+ },
186
+ headers=headers,
187
+ timeout=60
188
+ )
189
+
190
+ if response.status_code != 200:
191
+ return False, f"API returned {response.status_code}: {response.text}", {}
192
+
193
+ data = response.json()
194
+ answer = data.get("answer", "").lower()
195
+ citations = data.get("citations", [])
196
+ from_kb = data.get("from_knowledge_base", False)
197
+ confidence = data.get("confidence", 0.0)
198
+ metadata = data.get("metadata", {})
199
+ refused = metadata.get("refused", False)
200
+
201
+ # Check refusal behavior (STRICT)
202
+ if should_refuse:
203
+ # Check if response explicitly indicates refusal
204
+ refused = data.get("refused", False)
205
+ refusal_keywords = [
206
+ "couldn't find", "don't have", "not available", "contact support",
207
+ "not in the knowledge base", "could not verify", "not enough information",
208
+ "apologize", "couldn't find relevant information"
209
+ ]
210
+ has_refusal_keywords = any(kw in answer for kw in refusal_keywords)
211
+
212
+ # If answer was generated with citations, it's a FAIL (should have refused)
213
+ if citations and len(citations) > 0:
214
+ return False, (
215
+ f"Should have refused but generated answer with {len(citations)} citations. "
216
+ f"Answer: {answer[:300]}"
217
+ ), data
218
+
219
+ # If confidence is high and answer exists, it's a FAIL
220
+ if confidence >= 0.30 and answer and not has_refusal_keywords:
221
+ return False, (
222
+ f"Should have refused but generated answer with confidence {confidence:.2f}. "
223
+ f"Answer: {answer[:300]}"
224
+ ), data
225
+
226
+ # If not refused and no refusal keywords, it's a FAIL
227
+ if not refused and not has_refusal_keywords:
228
+ return False, (
229
+ f"Should have refused but didn't. "
230
+ f"refused={refused}, confidence={confidence:.2f}, citations={len(citations)}. "
231
+ f"Answer: {answer[:300]}"
232
+ ), data
233
+
234
+ # If we got here, it properly refused
235
+ return True, f"Properly refused (refused={refused}, confidence={confidence:.2f})", data
236
+
237
+ # Check for expected keywords
238
+ if expected_keywords:
239
+ found = [kw for kw in expected_keywords if kw.lower() in answer]
240
+ if not found:
241
+ return False, f"Expected keywords not found: {expected_keywords}. Answer: {answer[:200]}", data
242
+
243
+ # Check citations
244
+ if not should_refuse and from_kb:
245
+ if not citations:
246
+ return False, "Answer claims to be from KB but has no citations", data
247
+
248
+ # Check for forbidden content
249
+ if should_not_contain:
250
+ for forbidden in should_not_contain:
251
+ if forbidden.lower() in answer:
252
+ return False, f"Forbidden content found in answer: {forbidden}", data
253
+
254
+ # Check citation integrity
255
+ if citations and expected_keywords:
256
+ citation_text = " ".join([c.get("excerpt", "") for c in citations]).lower()
257
+ for kw in expected_keywords:
258
+ if kw.lower() in answer and kw.lower() not in citation_text:
259
+ # This is a warning, not a failure
260
+ pass
261
+
262
+ return True, f"Answer generated (confidence: {confidence:.2f}, citations: {len(citations)})", data
263
+
264
+ except Exception as e:
265
+ return False, f"Error: {str(e)}", {}
266
+
267
+
268
+ def main():
269
+ """Run all validation tests."""
270
+ print_header("RAG Pipeline Validation Suite")
271
+
272
+ # Check server
273
+ if not wait_for_server():
274
+ print("[FAIL] Cannot proceed without server")
275
+ return
276
+
277
+ client = httpx.Client(timeout=120.0)
278
+
279
+ # ========== PHASE 1: Upload Documents ==========
280
+ print_header("Phase 1: Upload Test Documents")
281
+
282
+ # Upload tenant A doc
283
+ print(f"\n📤 Uploading {TENANT_A_DOC.name} for {TEST_TENANT_A}...")
284
+ result = upload_document(client, TENANT_A_DOC, TEST_TENANT_A, TEST_USER_A, TEST_KB_A)
285
+ if result["success"]:
286
+ print("[OK] Upload successful")
287
+ print("⏳ Waiting for document processing (10 seconds)...")
288
+ time.sleep(10) # Wait longer for processing (parsing, chunking, embedding)
289
+ else:
290
+ print(f"[FAIL] Upload failed: {result.get('error')}")
291
+ return
292
+
293
+ # Upload tenant B doc
294
+ print(f"\n📤 Uploading {TENANT_B_DOC.name} for {TEST_TENANT_B}...")
295
+ result = upload_document(client, TENANT_B_DOC, TEST_TENANT_B, TEST_USER_B, TEST_KB_B)
296
+ if result["success"]:
297
+ print("[OK] Upload successful")
298
+ print("⏳ Waiting for document processing (10 seconds)...")
299
+ time.sleep(10) # Wait longer for processing (parsing, chunking, embedding)
300
+ else:
301
+ print(f"[FAIL] Upload failed: {result.get('error')}")
302
+ return
303
+
304
+ # ========== PHASE 2: Retrieval Tests ==========
305
+ print_header("Phase 2: Retrieval Accuracy Tests")
306
+
307
+ # Test 1: Tenant A retrieval
308
+ passed, reason = test_retrieval(
309
+ client,
310
+ "What is the refund window?",
311
+ TEST_TENANT_A,
312
+ TEST_USER_A,
313
+ TEST_KB_A,
314
+ expected_keywords=["7 days"],
315
+ should_not_contain=["30 days"]
316
+ )
317
+ print_test("Retrieval: Tenant A - Refund Window", passed, reason)
318
+
319
+ # Test 2: Tenant B retrieval
320
+ passed, reason = test_retrieval(
321
+ client,
322
+ "What is the refund window?",
323
+ TEST_TENANT_B,
324
+ TEST_USER_B,
325
+ TEST_KB_B,
326
+ expected_keywords=["30 days"],
327
+ should_not_contain=["7 days"]
328
+ )
329
+ print_test("Retrieval: Tenant B - Refund Window", passed, reason)
330
+
331
+ # Test 3: Tenant isolation (A should not get B's data)
332
+ passed, reason = test_retrieval(
333
+ client,
334
+ "Starter plan price",
335
+ TEST_TENANT_A,
336
+ TEST_USER_A,
337
+ TEST_KB_A,
338
+ expected_keywords=["499"],
339
+ should_not_contain=["999"]
340
+ )
341
+ print_test("Retrieval: Tenant A - Starter Plan Price (Isolation)", passed, reason)
342
+
343
+ # Test 4: Tenant isolation (B should not get A's data)
344
+ passed, reason = test_retrieval(
345
+ client,
346
+ "Starter plan price",
347
+ TEST_TENANT_B,
348
+ TEST_USER_B,
349
+ TEST_KB_B,
350
+ expected_keywords=["999"],
351
+ should_not_contain=["499"]
352
+ )
353
+ print_test("Retrieval: Tenant B - Starter Plan Price (Isolation)", passed, reason)
354
+
355
+ # ========== PHASE 3: Chat Tests ==========
356
+ print_header("Phase 3: Chat Endpoint Tests")
357
+
358
+ # Test 5: Tenant A chat - refund window
359
+ passed, reason, data = test_chat(
360
+ client,
361
+ "What is the refund window?",
362
+ TEST_TENANT_A,
363
+ TEST_USER_A,
364
+ TEST_KB_A,
365
+ expected_keywords=["7 days"],
366
+ should_not_contain=["30 days"]
367
+ )
368
+ print_test("Chat: Tenant A - Refund Window", passed, reason)
369
+
370
+ # Test 6: Tenant B chat - refund window
371
+ passed, reason, data = test_chat(
372
+ client,
373
+ "What is the refund window?",
374
+ TEST_TENANT_B,
375
+ TEST_USER_B,
376
+ TEST_KB_B,
377
+ expected_keywords=["30 days"],
378
+ should_not_contain=["7 days"]
379
+ )
380
+ print_test("Chat: Tenant B - Refund Window", passed, reason)
381
+
382
+ # Test 7: Tenant A chat - Starter plan
383
+ passed, reason, data = test_chat(
384
+ client,
385
+ "What is the Starter plan price?",
386
+ TEST_TENANT_A,
387
+ TEST_USER_A,
388
+ TEST_KB_A,
389
+ expected_keywords=["499"],
390
+ should_not_contain=["999"]
391
+ )
392
+ print_test("Chat: Tenant A - Starter Plan Price", passed, reason)
393
+
394
+ # Test 8: Tenant B chat - Starter plan
395
+ passed, reason, data = test_chat(
396
+ client,
397
+ "What is the Starter plan price?",
398
+ TEST_TENANT_B,
399
+ TEST_USER_B,
400
+ TEST_KB_B,
401
+ expected_keywords=["999"],
402
+ should_not_contain=["499"]
403
+ )
404
+ print_test("Chat: Tenant B - Starter Plan Price", passed, reason)
405
+
406
+ # Test 9: Hallucination refusal - out of scope
407
+ passed, reason, data = test_chat(
408
+ client,
409
+ "How to integrate ClientSphere with Shopify?",
410
+ TEST_TENANT_A,
411
+ TEST_USER_A,
412
+ TEST_KB_A,
413
+ should_refuse=True
414
+ )
415
+ print_test("Chat: Hallucination Refusal (Out of Scope)", passed, reason)
416
+
417
+ # Test 10: Citation integrity
418
+ passed, reason, data = test_chat(
419
+ client,
420
+ "How long do password reset links last?",
421
+ TEST_TENANT_A,
422
+ TEST_USER_A,
423
+ TEST_KB_A,
424
+ expected_keywords=["15"]
425
+ )
426
+ if passed:
427
+ citations = data.get("citations", [])
428
+ if citations:
429
+ print_test("Chat: Citation Integrity", True, f"Found {len(citations)} citations")
430
+ else:
431
+ print_test("Chat: Citation Integrity", False, "No citations provided")
432
+ else:
433
+ print_test("Chat: Citation Integrity", False, reason)
434
+
435
+ # ========== PHASE 4: Summary ==========
436
+ print_header("Test Summary")
437
+
438
+ total_tests = len(test_results)
439
+ passed_tests = sum(1 for r in test_results if r["passed"])
440
+ failed_tests = total_tests - passed_tests
441
+
442
+ print(f"\nTotal Tests: {total_tests}")
443
+ print(f"[PASS] Passed: {passed_tests}")
444
+ print(f"[FAIL] Failed: {failed_tests}")
445
+ print(f"Success Rate: {(passed_tests/total_tests*100):.1f}%")
446
+
447
+ if failed_tests > 0:
448
+ print("\n[FAIL] Failed Tests:")
449
+ for result in test_results:
450
+ if not result["passed"]:
451
+ print(f" - {result['test']}: {result['reason']}")
452
+
453
+ # Final verdict
454
+ print_header("Final Verdict")
455
+ if failed_tests == 0:
456
+ print("[PASS] ALL TESTS PASSED - RAG Pipeline is working correctly")
457
+ return 0
458
+ else:
459
+ print(f"[FAIL] {failed_tests} TEST(S) FAILED - Review issues above")
460
+ return 1
461
+
462
+
463
+ if __name__ == "__main__":
464
+ exit_code = main()
465
+ sys.exit(exit_code)
466
+