minhvtt committed on
Commit
8272622
·
verified ·
1 Parent(s): 1b5fb66

Delete test_advanced_features.py

Browse files
Files changed (1) hide show
  1. test_advanced_features.py +0 -260
test_advanced_features.py DELETED
@@ -1,260 +0,0 @@
1
- """
2
- Test script for Advanced RAG features
3
- Demonstrates new capabilities: multiple texts/images indexing and advanced RAG chat
4
- """
5
-
6
- import requests
7
- import json
8
- from typing import List, Optional
9
-
10
-
11
class AdvancedRAGTester:
    """Console test client for the Advanced RAG HTTP API.

    Each ``test_*`` method sends one request to the server at ``base_url``,
    prints a human-readable report to stdout, and returns the decoded JSON
    body, or ``None`` when the HTTP request fails.
    """

    def __init__(self, base_url: str = "http://localhost:8000", timeout: Optional[float] = None):
        """
        Create a client bound to one server.

        Args:
            base_url: Root URL of the API server. Endpoint paths such as
                ``/index`` and ``/chat`` are appended to it as-is.
            timeout: Seconds to wait for each HTTP request, forwarded to
                ``requests.post``. ``None`` (the default) waits indefinitely.
        """
        self.base_url = base_url
        self.timeout = timeout

    @staticmethod
    def _report_request_error(error):
        """Print a failed request's error and, when the server replied, its body."""
        print(f"\n✗ ERROR: {error}")
        # ``response`` is None when no reply was received (e.g. connection refused).
        # Compare against None explicitly: a Response object is falsy for 4xx/5xx.
        response = getattr(error, 'response', None)
        if response is not None:
            print(f" Response: {response.text}")

    def test_multiple_index(self, doc_id: str, texts: List[str], image_paths: Optional[List[str]] = None):
        """
        Test indexing with multiple texts and images

        Posts a multipart form to ``{base_url}/index``. Lists longer than 10
        items are truncated to their first 10 with a warning. Images that
        cannot be opened are reported and skipped.

        Args:
            doc_id: Document ID
            texts: List of texts (max 10)
            image_paths: List of image file paths (max 10)

        Returns:
            The decoded JSON response, or ``None`` if the HTTP request failed.
        """
        print(f"\n{'='*60}")
        print(f"TEST: Indexing document '{doc_id}' with multiple texts/images")
        print(f"{'='*60}")

        # Prepare form data
        data = {'id': doc_id}

        # Add texts
        if texts:
            if len(texts) > 10:
                print("WARNING: Maximum 10 texts allowed. Taking first 10.")
                texts = texts[:10]
            data['texts'] = texts
            print(f"✓ Texts: {len(texts)} items")

        # Open the images inside the try/finally so that every handle opened so
        # far is closed even if a later open() or the request itself raises.
        files = []
        try:
            if image_paths:
                if len(image_paths) > 10:
                    print("WARNING: Maximum 10 images allowed. Taking first 10.")
                    image_paths = image_paths[:10]

                for img_path in image_paths:
                    try:
                        files.append(('images', open(img_path, 'rb')))
                    except FileNotFoundError:
                        print(f"WARNING: Image not found: {img_path}")
                    except OSError as exc:
                        # e.g. permission denied, or the path is a directory
                        print(f"WARNING: Could not open image {img_path}: {exc}")

                print(f"✓ Images: {len(files)} files")

            # Make request
            try:
                response = requests.post(
                    f"{self.base_url}/index",
                    data=data,
                    files=files,
                    timeout=self.timeout,
                )
                response.raise_for_status()
                result = response.json()
            except requests.exceptions.RequestException as e:
                self._report_request_error(e)
                return None

            print("\n✓ SUCCESS")
            print(f" - Document ID: {result['id']}")
            print(f" - Message: {result['message']}")
            return result

        finally:
            # Close file handles
            for _, file_obj in files:
                file_obj.close()

    def test_advanced_rag_chat(
        self,
        message: str,
        hf_token: Optional[str] = None,
        use_advanced_rag: bool = True,
        use_reranking: bool = True,
        use_compression: bool = True,
        top_k: int = 3,
        score_threshold: float = 0.5
    ):
        """
        Test advanced RAG chat

        Posts a JSON payload to ``{base_url}/chat`` (always with
        ``use_rag=True``) and prints the answer, the retrieved context and,
        when present in the response, the pipeline statistics.

        Args:
            message: User question
            hf_token: Hugging Face token (optional); only sent when truthy
            use_advanced_rag: Use advanced RAG pipeline
            use_reranking: Enable reranking
            use_compression: Enable context compression
            top_k: Number of documents to retrieve
            score_threshold: Minimum relevance score

        Returns:
            The decoded JSON response, or ``None`` if the HTTP request failed.
        """
        print(f"\n{'='*60}")
        print("TEST: Advanced RAG Chat")
        print(f"{'='*60}")
        print(f"Question: {message}")
        print(f"Advanced RAG: {use_advanced_rag}")
        print(f"Reranking: {use_reranking}")
        print(f"Compression: {use_compression}")

        payload = {
            'message': message,
            'use_rag': True,
            'use_advanced_rag': use_advanced_rag,
            'use_reranking': use_reranking,
            'use_compression': use_compression,
            'top_k': top_k,
            'score_threshold': score_threshold,
        }

        if hf_token:
            payload['hf_token'] = hf_token

        try:
            response = requests.post(
                f"{self.base_url}/chat",
                json=payload,
                timeout=self.timeout,
            )
            response.raise_for_status()
            result = response.json()
        except requests.exceptions.RequestException as e:
            self._report_request_error(e)
            return None

        print("\n✓ SUCCESS")
        print("\n--- Answer ---")
        print(result['response'])

        print(f"\n--- Retrieved Context ({len(result['context_used'])} documents) ---")
        for i, ctx in enumerate(result['context_used'], 1):
            print(f"{i}. [{ctx['id']}] Confidence: {ctx['confidence']:.2%}")
            # Only the first 100 characters of each document are shown.
            text_preview = ctx['metadata'].get('text', '')[:100]
            print(f" Text: {text_preview}...")

        if result.get('rag_stats'):
            print("\n--- RAG Pipeline Statistics ---")
            stats = result['rag_stats']
            print(f" Original query: {stats.get('original_query')}")
            print(f" Expanded queries: {stats.get('expanded_queries')}")
            print(f" Initial results: {stats.get('initial_results')}")
            print(f" After reranking: {stats.get('after_rerank')}")
            print(f" After compression: {stats.get('after_compression')}")

        return result

    def compare_basic_vs_advanced_rag(self, message: str, hf_token: Optional[str] = None):
        """
        Compare basic RAG vs advanced RAG side by side

        Runs the same question through ``test_advanced_rag_chat`` twice, first
        with ``use_advanced_rag=False`` and then with ``True``, and prints a
        short summary of how many documents each run retrieved.

        Args:
            message: User question sent to both runs
            hf_token: Hugging Face token (optional), forwarded to both runs
        """
        print(f"\n{'='*60}")
        print("COMPARISON: Basic RAG vs Advanced RAG")
        print(f"{'='*60}")
        print(f"Question: {message}\n")

        # Test Basic RAG
        print("\n--- BASIC RAG ---")
        basic_result = self.test_advanced_rag_chat(
            message=message,
            hf_token=hf_token,
            use_advanced_rag=False
        )

        # Test Advanced RAG
        print("\n--- ADVANCED RAG ---")
        advanced_result = self.test_advanced_rag_chat(
            message=message,
            hf_token=hf_token,
            use_advanced_rag=True
        )

        # Compare
        print(f"\n{'='*60}")
        print("COMPARISON SUMMARY")
        print(f"{'='*60}")

        if basic_result and advanced_result:
            print("Basic RAG:")
            print(f" - Retrieved docs: {len(basic_result['context_used'])}")

            print("\nAdvanced RAG:")
            print(f" - Retrieved docs: {len(advanced_result['context_used'])}")
            if advanced_result.get('rag_stats'):
                stats = advanced_result['rag_stats']
                print(f" - Query expansion: {len(stats.get('expanded_queries', []))} variants")
                print(f" - Initial retrieval: {stats.get('initial_results', 0)} docs")
                print(f" - After reranking: {stats.get('after_rerank', 0)} docs")
        else:
            # Say so explicitly instead of leaving the summary section blank.
            print("Comparison unavailable: at least one request failed.")
193
-
194
-
195
def main():
    """Drive the full demo: index two sample documents, then run three chat checks.

    Talks to the server at the tester's default base URL and reports
    everything on stdout; nothing is returned.
    """
    banner = "=" * 60
    client = AdvancedRAGTester()

    print(banner)
    print("ADVANCED RAG FEATURE TESTS")
    print(banner)

    # Test 1: a festival document made of several short texts (no images)
    festival_texts = [
        "Festival âm nhạc quốc tế Hà Nội 2025",
        "Thời gian: 15-17 tháng 11 năm 2025",
        "Địa điểm: Công viên Thống Nhất, Hà Nội",
        "Line-up: Sơn Tùng MTP, Đen Vâu, Hoàng Thùy Linh, Mỹ Tâm",
        "Giá vé: Early bird 500.000đ, VIP 2.000.000đ",
        "Dự kiến 50.000 khán giả tham dự",
        "3 sân khấu chính, 5 food court, khu vực cắm trại",
    ]
    print("\n\n### TEST 1: Index Multiple Texts ###")
    client.test_multiple_index(doc_id="event_music_festival_2025", texts=festival_texts)

    # Test 2: a second, unrelated document
    safety_texts = [
        "Vũ khí và đồ vật nguy hiểm bị cấm mang vào sự kiện",
        "Dao, kiếm, súng và các loại vũ khí nguy hiểm nghiêm cấm",
        "An ninh sẽ kiểm tra tất cả túi xách và đồ mang theo",
        "Vi phạm sẽ bị tịch thu và có thể bị trục xuất khỏi sự kiện",
    ]
    print("\n\n### TEST 2: Index Another Document ###")
    client.test_multiple_index(doc_id="safety_guidelines", texts=safety_texts)

    # Test 3: basic pipeline, no HF token supplied
    print("\n\n### TEST 3: Basic RAG Chat (No LLM) ###")
    client.test_advanced_rag_chat(
        message="Festival Hà Nội diễn ra khi nào?",
        use_advanced_rag=False,
    )

    # Test 4: advanced pipeline with reranking and compression switched on
    print("\n\n### TEST 4: Advanced RAG Chat (No LLM) ###")
    client.test_advanced_rag_chat(
        message="Festival Hà Nội diễn ra khi nào và có những nghệ sĩ nào?",
        use_advanced_rag=True,
        use_reranking=True,
        use_compression=True,
    )

    # Test 5: same question through both pipelines
    print("\n\n### TEST 5: Comparison Test ###")
    client.compare_basic_vs_advanced_rag(message="Dao có được mang vào sự kiện không?")

    print("\n\n" + banner)
    print("ALL TESTS COMPLETED")
    print(banner)
    print("\nNOTE: To test with actual LLM responses, add your Hugging Face token:")
    print(" tester.test_advanced_rag_chat(message='...', hf_token='hf_xxxxx')")


# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()