"""
Test script for Advanced RAG features.

Demonstrates the new capabilities: indexing multiple texts/images per document
and advanced RAG chat.
"""
| |
|
import json
from contextlib import ExitStack
from typing import List, Optional

import requests
| |
|
| |
|
class AdvancedRAGTester:
    """Test client for the Advanced RAG API.

    Sends requests to the ``/index`` and ``/chat`` endpoints of the server at
    ``base_url`` and prints a human-readable report for each call.
    """

    # Maximum number of texts (and, separately, images) sent per index request.
    MAX_ITEMS = 10

    def __init__(self, base_url: str = "http://localhost:8000", timeout: Optional[float] = 60.0):
        """
        Create a test client.

        Args:
            base_url: Root URL of the API server.
            timeout: Seconds to wait for each HTTP request before giving up.
                Pass None to wait indefinitely.
        """
        self.base_url = base_url
        self.timeout = timeout

    def _url(self, path: str) -> str:
        """Join ``base_url`` and ``path`` with exactly one slash between them."""
        return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"

    @staticmethod
    def _limit(items: List, label: str) -> List:
        """Return at most MAX_ITEMS entries of ``items``, warning when truncating."""
        limit = AdvancedRAGTester.MAX_ITEMS
        if len(items) > limit:
            print(f"WARNING: Maximum {limit} {label} allowed. Taking first {limit}.")
            return items[:limit]
        return items

    def _post_json(self, path: str, **kwargs) -> Optional[dict]:
        """POST to ``path`` and return the decoded JSON body, or None on failure.

        Failures (connection problems, timeouts, HTTP error statuses, and
        bodies that are not valid JSON) are printed rather than raised.
        """
        try:
            response = requests.post(self._url(path), timeout=self.timeout, **kwargs)
            response.raise_for_status()
            return response.json()
        # ValueError covers JSON decoding errors on requests versions where the
        # decode error is not a RequestException subclass.
        except (requests.exceptions.RequestException, ValueError) as e:
            print(f"\n✗ ERROR: {e}")
            # Only some failures carry a server response (e.g. HTTP errors).
            failed_response = getattr(e, 'response', None)
            if failed_response is not None:
                print(f" Response: {failed_response.text}")
            return None

    def test_multiple_index(self, doc_id: str, texts: List[str], image_paths: Optional[List[str]] = None) -> Optional[dict]:
        """
        Test indexing with multiple texts and images

        Args:
            doc_id: Document ID
            texts: List of texts (max 10; extra entries are dropped)
            image_paths: List of image file paths (max 10; extra entries are
                dropped, and files that cannot be opened are skipped)

        Returns:
            The decoded JSON response, or None if the request failed.
        """
        print(f"\n{'='*60}")
        print(f"TEST: Indexing document '{doc_id}' with multiple texts/images")
        print(f"{'='*60}")

        data = {'id': doc_id}

        if texts:
            texts = self._limit(texts, "texts")
            data['texts'] = texts
            print(f"✓ Texts: {len(texts)} items")

        # ExitStack closes every opened image, even when a later open() or the
        # request itself raises.
        with ExitStack() as stack:
            files = []
            if image_paths:
                for img_path in self._limit(image_paths, "images"):
                    try:
                        handle = stack.enter_context(open(img_path, 'rb'))
                    except FileNotFoundError:
                        print(f"WARNING: Image not found: {img_path}")
                    except OSError as exc:
                        print(f"WARNING: Could not open image {img_path}: {exc}")
                    else:
                        files.append(('images', handle))

                print(f"✓ Images: {len(files)} files")

            result = self._post_json("/index", data=data, files=files)

        if result is None:
            return None

        print("\n✓ SUCCESS")
        print(f" - Document ID: {result['id']}")
        print(f" - Message: {result['message']}")
        return result

    def test_advanced_rag_chat(
        self,
        message: str,
        hf_token: Optional[str] = None,
        use_advanced_rag: bool = True,
        use_reranking: bool = True,
        use_compression: bool = True,
        top_k: int = 3,
        score_threshold: float = 0.5
    ) -> Optional[dict]:
        """
        Test advanced RAG chat

        Args:
            message: User question
            hf_token: Hugging Face token (optional; omitted from the payload
                when empty)
            use_advanced_rag: Use advanced RAG pipeline
            use_reranking: Enable reranking
            use_compression: Enable context compression
            top_k: Number of documents to retrieve
            score_threshold: Minimum relevance score

        Returns:
            The decoded JSON response, or None if the request failed.
        """
        print(f"\n{'='*60}")
        print("TEST: Advanced RAG Chat")
        print(f"{'='*60}")
        print(f"Question: {message}")
        print(f"Advanced RAG: {use_advanced_rag}")
        print(f"Reranking: {use_reranking}")
        print(f"Compression: {use_compression}")

        payload = {
            'message': message,
            'use_rag': True,
            'use_advanced_rag': use_advanced_rag,
            'use_reranking': use_reranking,
            'use_compression': use_compression,
            'top_k': top_k,
            'score_threshold': score_threshold,
        }

        if hf_token:
            payload['hf_token'] = hf_token

        result = self._post_json("/chat", json=payload)
        if result is None:
            return None

        print("\n✓ SUCCESS")
        print("\n--- Answer ---")
        print(result['response'])

        print(f"\n--- Retrieved Context ({len(result['context_used'])} documents) ---")
        for i, ctx in enumerate(result['context_used'], 1):
            print(f"{i}. [{ctx['id']}] Confidence: {ctx['confidence']:.2%}")
            text_preview = ctx['metadata'].get('text', '')[:100]
            print(f" Text: {text_preview}...")

        # Pipeline statistics are optional in the response.
        if result.get('rag_stats'):
            print("\n--- RAG Pipeline Statistics ---")
            stats = result['rag_stats']
            print(f" Original query: {stats.get('original_query')}")
            print(f" Expanded queries: {stats.get('expanded_queries')}")
            print(f" Initial results: {stats.get('initial_results')}")
            print(f" After reranking: {stats.get('after_rerank')}")
            print(f" After compression: {stats.get('after_compression')}")

        return result

    def compare_basic_vs_advanced_rag(self, message: str, hf_token: Optional[str] = None):
        """Compare basic RAG vs advanced RAG side by side

        Runs the same question through the chat endpoint twice, first with
        ``use_advanced_rag=False`` and then with ``use_advanced_rag=True``,
        and prints a short summary of both results.

        Args:
            message: User question
            hf_token: Hugging Face token (optional)
        """
        print(f"\n{'='*60}")
        print("COMPARISON: Basic RAG vs Advanced RAG")
        print(f"{'='*60}")
        print(f"Question: {message}\n")

        print("\n--- BASIC RAG ---")
        basic_result = self.test_advanced_rag_chat(
            message=message,
            hf_token=hf_token,
            use_advanced_rag=False
        )

        print("\n--- ADVANCED RAG ---")
        advanced_result = self.test_advanced_rag_chat(
            message=message,
            hf_token=hf_token,
            use_advanced_rag=True
        )

        print(f"\n{'='*60}")
        print("COMPARISON SUMMARY")
        print(f"{'='*60}")

        if basic_result and advanced_result:
            print("Basic RAG:")
            print(f" - Retrieved docs: {len(basic_result['context_used'])}")

            print("\nAdvanced RAG:")
            print(f" - Retrieved docs: {len(advanced_result['context_used'])}")
            if advanced_result.get('rag_stats'):
                stats = advanced_result['rag_stats']
                print(f" - Query expansion: {len(stats.get('expanded_queries', []))} variants")
                print(f" - Initial retrieval: {stats.get('initial_results', 0)} docs")
                print(f" - After reranking: {stats.get('after_rerank', 0)} docs")
        else:
            # Explain the otherwise empty summary when a request produced nothing.
            print("Comparison unavailable: one or both requests returned no result.")
|
| |
|
def main():
    """Run the demo: index two sample documents, then query the chat endpoint."""
    banner = "=" * 60
    tester = AdvancedRAGTester()

    for line in (banner, "ADVANCED RAG FEATURE TESTS", banner):
        print(line)

    # Sample documents: (section heading, document id, texts to index).
    indexing_runs = (
        (
            "\n\n### TEST 1: Index Multiple Texts ###",
            "event_music_festival_2025",
            [
                "Festival âm nhạc quốc tế Hà Nội 2025",
                "Thời gian: 15-17 tháng 11 năm 2025",
                "Địa điểm: Công viên Thống Nhất, Hà Nội",
                "Line-up: Sơn Tùng MTP, Đen Vâu, Hoàng Thùy Linh, Mỹ Tâm",
                "Giá vé: Early bird 500.000đ, VIP 2.000.000đ",
                "Dự kiến 50.000 khán giả tham dự",
                "3 sân khấu chính, 5 food court, khu vực cắm trại",
            ],
        ),
        (
            "\n\n### TEST 2: Index Another Document ###",
            "safety_guidelines",
            [
                "Vũ khí và đồ vật nguy hiểm bị cấm mang vào sự kiện",
                "Dao, kiếm, súng và các loại vũ khí nguy hiểm nghiêm cấm",
                "An ninh sẽ kiểm tra tất cả túi xách và đồ mang theo",
                "Vi phạm sẽ bị tịch thu và có thể bị trục xuất khỏi sự kiện",
            ],
        ),
    )
    for heading, document_id, document_texts in indexing_runs:
        print(heading)
        tester.test_multiple_index(doc_id=document_id, texts=document_texts)

    print("\n\n### TEST 3: Basic RAG Chat (No LLM) ###")
    tester.test_advanced_rag_chat(
        message="Festival Hà Nội diễn ra khi nào?",
        use_advanced_rag=False,
    )

    print("\n\n### TEST 4: Advanced RAG Chat (No LLM) ###")
    tester.test_advanced_rag_chat(
        message="Festival Hà Nội diễn ra khi nào và có những nghệ sĩ nào?",
        use_advanced_rag=True,
        use_reranking=True,
        use_compression=True,
    )

    print("\n\n### TEST 5: Comparison Test ###")
    tester.compare_basic_vs_advanced_rag(
        message="Dao có được mang vào sự kiện không?",
    )

    print(f"\n\n{banner}")
    print("ALL TESTS COMPLETED")
    print(banner)
    print("\nNOTE: To test with actual LLM responses, add your Hugging Face token:")
    print(" tester.test_advanced_rag_chat(message='...', hf_token='hf_xxxxx')")
|
| |
|
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
| |
|