# File size: 8,341 Bytes
# b02e301
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
# Put the parent of this file's directory on sys.path before the project
# import at the bottom of this block runs.
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))

from openai import OpenAI
from dotenv import load_dotenv
import os

# Load the `.env` file located four directories above this file's directory.
# override=True is passed so values from that file take precedence over
# variables already present in the environment.
load_dotenv(Path(__file__).parent.parent.parent.parent.parent / ".env", override=True)

# NOTE(review): presumably `rag_system` is resolved through the sys.path entry
# added above, which is why this import comes last — confirm the package layout.
from rag_system import QueryExpander, HybridRetriever, RAGSystem


def test_query_expansion():
    """Check that ``QueryExpander.expand_query`` returns a list with the original query.

    Builds an OpenAI client from the ``OPENAI_API_KEY`` environment variable,
    calls ``expand_query`` with ``num_variations=2`` on a sample question, and
    prints the original query followed by every returned variation.

    Returns:
        bool: ``True`` when every assertion holds; ``False`` when any
        exception (including a failed assertion) is raised. The error is
        printed instead of propagated so the caller can tally results.
    """
    print("\n" + "="*60)
    print("TEST: Query Expansion")
    print("="*60)

    try:
        client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        expander = QueryExpander(client)

        query = "What are your programming skills?"
        expanded = expander.expand_query(query, num_variations=2)

        assert isinstance(expanded, list), "Should return a list"
        assert len(expanded) >= 1, "Should have at least original query"
        assert query in expanded, "Should include original query"

        print(f"✓ Original: {query}")
        # Only membership of the original query is asserted above, so select
        # the variations by value rather than assuming it sits at index 0.
        variations = [q for q in expanded if q != query]
        for i, q in enumerate(variations, 1):
            print(f"✓ Variation {i}: {q}")

        print("✅ Query expansion test PASSED")
        return True
    except Exception as e:
        print(f"❌ Query expansion test FAILED: {e}")
        return False


def test_retriever_initialization():
    """Check that a new ``HybridRetriever`` exposes its core components.

    Constructs a retriever with ``data_dir="data/test_retriever"`` and asserts
    that its ``embedder``, ``reranker`` and ``chroma_client`` attributes are
    all non-``None``.

    Returns:
        bool: ``True`` when every assertion holds; ``False`` when any
        exception (including a failed assertion) is raised. The error is
        printed instead of propagated.
    """
    print("\n" + "="*60)
    print("TEST: Retriever Initialization")
    print("="*60)

    try:
        retriever = HybridRetriever(data_dir="data/test_retriever")

        assert retriever.embedder is not None, "Embedder should be initialized"
        assert retriever.reranker is not None, "Reranker should be initialized"
        assert retriever.chroma_client is not None, "ChromaDB client should be initialized"

        print("✓ Embedder loaded")
        print("✓ Reranker loaded")
        print("✓ ChromaDB client initialized")
        print("✅ Retriever initialization test PASSED")
        return True
    except Exception as e:
        print(f"❌ Retriever initialization test FAILED: {e}")
        return False


def test_chunking():
    """Check that ``HybridRetriever.chunk_text`` yields a non-empty list of strings.

    Builds a text of 100 space-separated tokens (``word0`` … ``word99``) and
    chunks it with ``chunk_size=20`` and ``overlap=5``. Only the type and
    non-emptiness of the result are asserted; chunk sizes are printed for
    information but not verified.

    Returns:
        bool: ``True`` when every assertion holds; ``False`` when any
        exception (including a failed assertion) is raised. The error is
        printed instead of propagated.
    """
    print("\n" + "="*60)
    print("TEST: Text Chunking")
    print("="*60)

    try:
        retriever = HybridRetriever(data_dir="data/test_chunking")

        text = " ".join([f"word{i}" for i in range(100)])
        chunks = retriever.chunk_text(text, chunk_size=20, overlap=5)

        assert isinstance(chunks, list), "Should return a list"
        assert len(chunks) > 0, "Should create at least one chunk"
        assert all(isinstance(c, str) for c in chunks), "All chunks should be strings"

        print(f"✓ Created {len(chunks)} chunks from {len(text)} character text")
        print(f"✓ First chunk: {len(chunks[0].split())} words")
        print("✅ Chunking test PASSED")
        return True
    except Exception as e:
        print(f"❌ Chunking test FAILED: {e}")
        return False


def test_document_indexing():
    """Check that ``HybridRetriever.index_documents`` populates its indexes.

    Indexes three short sample documents with ``chunk_size=20`` and
    ``overlap=5``, then asserts that ``documents`` is non-empty and that the
    ``bm25`` and ``collection`` attributes are non-``None``.

    Returns:
        bool: ``True`` when every assertion holds; ``False`` when any
        exception (including a failed assertion) is raised. The error is
        printed instead of propagated.
    """
    print("\n" + "="*60)
    print("TEST: Document Indexing")
    print("="*60)

    try:
        retriever = HybridRetriever(data_dir="data/test_indexing")

        test_docs = {
            "doc1": "Python is a high-level programming language. It is widely used for web development and data science.",
            "doc2": "Machine learning involves training models on data. It uses algorithms like neural networks.",
            "doc3": "FastAPI is a modern web framework for Python. It is fast and easy to use."
        }

        retriever.index_documents(test_docs, chunk_size=20, overlap=5)

        assert retriever.documents is not None, "Documents should be indexed"
        assert len(retriever.documents) > 0, "Should have indexed chunks"
        assert retriever.bm25 is not None, "BM25 index should be created"
        assert retriever.collection is not None, "ChromaDB collection should be created"

        print(f"✓ Indexed {len(test_docs)} documents")
        print(f"✓ Created {len(retriever.documents)} chunks")
        print("✓ BM25 index created")
        print("✓ Semantic index created")
        print("✅ Document indexing test PASSED")
        return True
    except Exception as e:
        print(f"❌ Document indexing test FAILED: {e}")
        return False


def test_retrieval_methods():
    """Check that each retrieval method of ``HybridRetriever`` returns a list.

    Indexes three sample documents (``chunk_size=15``, ``overlap=3``) and runs
    the query ``"Python programming"`` through ``retrieve_bm25``,
    ``retrieve_semantic`` and ``retrieve_hybrid`` (each with ``top_k=2``), then
    passes the hybrid results to ``rerank`` with ``top_k=1``. Only the return
    type of each call is asserted; result counts are printed, not verified.

    Returns:
        bool: ``True`` when every assertion holds; ``False`` when any
        exception (including a failed assertion) is raised. The error is
        printed instead of propagated.
    """
    print("\n" + "="*60)
    print("TEST: Retrieval Methods")
    print("="*60)

    try:
        retriever = HybridRetriever(data_dir="data/test_methods")

        test_docs = {
            "doc1": "Python programming language for web development and machine learning applications",
            "doc2": "JavaScript is used for frontend development with React and Vue frameworks",
            "doc3": "SQL databases like PostgreSQL store structured data efficiently"
        }

        retriever.index_documents(test_docs, chunk_size=15, overlap=3)

        query = "Python programming"

        bm25_results = retriever.retrieve_bm25(query, top_k=2)
        assert isinstance(bm25_results, list), "BM25 should return a list"
        print(f"✓ BM25 retrieval: {len(bm25_results)} results")

        semantic_results = retriever.retrieve_semantic(query, top_k=2)
        assert isinstance(semantic_results, list), "Semantic should return a list"
        print(f"✓ Semantic retrieval: {len(semantic_results)} results")

        hybrid_results = retriever.retrieve_hybrid(query, top_k=2)
        assert isinstance(hybrid_results, list), "Hybrid should return a list"
        print(f"✓ Hybrid retrieval: {len(hybrid_results)} results")

        # Reranking consumes the hybrid results obtained just above.
        reranked = retriever.rerank(query, hybrid_results, top_k=1)
        assert isinstance(reranked, list), "Reranking should return a list"
        print(f"✓ Reranking: {len(reranked)} results")

        print("✅ Retrieval methods test PASSED")
        return True
    except Exception as e:
        print(f"❌ Retrieval methods test FAILED: {e}")
        return False


def test_rag_system():
    """Exercise ``RAGSystem`` end to end: load a knowledge base, then query it.

    Builds an OpenAI client from the ``OPENAI_API_KEY`` environment variable,
    loads three sample documents (``chunk_size=20``, ``overlap=5``) and issues
    one query with ``method="hybrid"`` and ``top_k=3``. The response must
    contain the keys ``"answer"``, ``"context"`` and ``"method"``, and
    ``"context"`` must be non-empty.

    Returns:
        bool: ``True`` when every assertion holds; ``False`` when any
        exception (including a failed assertion) is raised. The error is
        printed instead of propagated.
    """
    print("\n" + "="*60)
    print("TEST: RAG System End-to-End")
    print("="*60)

    try:
        client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        rag_system = RAGSystem(client, data_dir="data/test_rag")

        test_docs = {
            "summary": "I am an experienced AI engineer with 5 years of Python development",
            "projects": "Built RAG systems, multi-agent frameworks, and production ML pipelines",
            "stack": "Expert in Python, FastAPI, LangChain, ChromaDB, and OpenAI APIs"
        }

        rag_system.load_knowledge_base(test_docs, chunk_size=20, overlap=5)

        system_prompt = "Answer questions about professional background."
        response = rag_system.query(
            "What programming languages do you know?",
            system_prompt,
            method="hybrid",
            top_k=3
        )

        assert "answer" in response, "Response should contain answer"
        assert "context" in response, "Response should contain context"
        assert "method" in response, "Response should contain method"
        assert len(response["context"]) > 0, "Should retrieve some context"

        print(f"✓ Retrieved {len(response['context'])} context documents")
        print(f"✓ Generated answer: {len(response['answer'])} characters")
        print(f"✓ Method used: {response['method']}")
        print("✅ RAG system test PASSED")
        return True
    except Exception as e:
        print(f"❌ RAG system test FAILED: {e}")
        return False


def run_all_tests(tests=None):
    """Run test callables in order and print how many of them passed.

    Args:
        tests: Optional iterable of zero-argument callables, each returning a
            truthy value on success. When omitted, the six tests defined in
            this module are run in their declared order.

    Returns:
        bool: ``True`` when every test passed (vacuously ``True`` for an
        empty selection), otherwise ``False``.
    """
    print("\n" + "="*70)
    print("RUNNING RAG SYSTEM TESTS")
    print("="*70)

    if tests is None:
        # Built lazily so a caller-supplied selection never touches the defaults.
        tests = [
            test_query_expansion,
            test_retriever_initialization,
            test_chunking,
            test_document_indexing,
            test_retrieval_methods,
            test_rag_system,
        ]

    # Coerce to bool so a test that returns None (or any non-bool value) is
    # counted as a failure instead of breaking the sum() below.
    results = [bool(test()) for test in tests]

    print("\n" + "="*70)
    print(f"RESULTS: {sum(results)}/{len(results)} tests passed")
    print("="*70)

    return all(results)


if __name__ == "__main__":
    # Script entry point: exit status 0 when every test passed, 1 otherwise.
    sys.exit(0 if run_all_tests() else 1)