Spaces:
Running
Running
| import sys | |
| from pathlib import Path | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| from openai import OpenAI | |
| from dotenv import load_dotenv | |
| import os | |
| load_dotenv(Path(__file__).parent.parent.parent.parent.parent / ".env", override=True) | |
| from rag_system import QueryExpander, HybridRetriever, RAGSystem | |
def test_query_expansion():
    """Check that ``QueryExpander.expand_query`` returns usable query variations.

    Builds an ``OpenAI`` client from the ``OPENAI_API_KEY`` environment
    variable, wraps it in a ``QueryExpander`` and requests two variations of a
    sample question. The result must be a non-empty list that contains the
    original query.

    Returns:
        bool: True when every check passes. False when any call or assertion
        raises; the error is printed instead of propagated so the runner can
        continue with the remaining tests.
    """
    print("\n" + "=" * 60)
    print("TEST: Query Expansion")
    print("=" * 60)

    try:
        client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        expander = QueryExpander(client)

        query = "What are your programming skills?"
        expanded = expander.expand_query(query, num_variations=2)

        assert isinstance(expanded, list), "Should return a list"
        assert len(expanded) >= 1, "Should have at least original query"
        assert query in expanded, "Should include original query"

        print(f"✓ Original: {query}")
        # Only membership of the original query is asserted above, so select
        # the variations by value rather than assuming the original is first.
        variations = [candidate for candidate in expanded if candidate != query]
        for i, q in enumerate(variations, 1):
            print(f"✓ Variation {i}: {q}")

        print("✅ Query expansion test PASSED")
        return True
    except Exception as e:
        print(f"❌ Query expansion test FAILED: {e}")
        return False
def test_retriever_initialization():
    """Check that a freshly constructed ``HybridRetriever`` has its components set.

    Constructs ``HybridRetriever(data_dir="data/test_retriever")`` and asserts
    that its ``embedder``, ``reranker`` and ``chroma_client`` attributes are
    not ``None``.

    Returns:
        bool: True when every check passes. False when construction or any
        assertion raises; the error is printed instead of propagated.
    """
    print("\n" + "=" * 60)
    print("TEST: Retriever Initialization")
    print("=" * 60)

    try:
        retriever = HybridRetriever(data_dir="data/test_retriever")

        assert retriever.embedder is not None, "Embedder should be initialized"
        assert retriever.reranker is not None, "Reranker should be initialized"
        assert retriever.chroma_client is not None, "ChromaDB client should be initialized"

        print("✓ Embedder loaded")
        print("✓ Reranker loaded")
        print("✓ ChromaDB client initialized")

        print("✅ Retriever initialization test PASSED")
        return True
    except Exception as e:
        print(f"❌ Retriever initialization test FAILED: {e}")
        return False
def test_chunking():
    """Check that ``HybridRetriever.chunk_text`` splits text into string chunks.

    Feeds a 100-token synthetic text (``word0 … word99``) to ``chunk_text``
    with ``chunk_size=20`` and ``overlap=5`` and asserts that the result is a
    non-empty list whose items are all strings.

    Returns:
        bool: True when every check passes. False when any call or assertion
        raises; the error is printed instead of propagated.
    """
    print("\n" + "=" * 60)
    print("TEST: Text Chunking")
    print("=" * 60)

    try:
        retriever = HybridRetriever(data_dir="data/test_chunking")

        text = " ".join([f"word{i}" for i in range(100)])
        chunks = retriever.chunk_text(text, chunk_size=20, overlap=5)

        assert isinstance(chunks, list), "Should return a list"
        assert len(chunks) > 0, "Should create at least one chunk"
        assert all(isinstance(c, str) for c in chunks), "All chunks should be strings"

        print(f"✓ Created {len(chunks)} chunks from {len(text)} character text")
        print(f"✓ First chunk: {len(chunks[0].split())} words")

        print("✅ Chunking test PASSED")
        return True
    except Exception as e:
        print(f"❌ Chunking test FAILED: {e}")
        return False
def test_document_indexing():
    """Check that ``HybridRetriever.index_documents`` populates its indexes.

    Indexes three short sample documents with ``chunk_size=20`` and
    ``overlap=5``, then asserts that ``documents`` is non-empty and that the
    ``bm25`` and ``collection`` attributes are set.

    Returns:
        bool: True when every check passes. False when any call or assertion
        raises; the error is printed instead of propagated.
    """
    print("\n" + "=" * 60)
    print("TEST: Document Indexing")
    print("=" * 60)

    try:
        retriever = HybridRetriever(data_dir="data/test_indexing")

        test_docs = {
            "doc1": "Python is a high-level programming language. It is widely used for web development and data science.",
            "doc2": "Machine learning involves training models on data. It uses algorithms like neural networks.",
            "doc3": "FastAPI is a modern web framework for Python. It is fast and easy to use.",
        }

        retriever.index_documents(test_docs, chunk_size=20, overlap=5)

        assert retriever.documents is not None, "Documents should be indexed"
        assert len(retriever.documents) > 0, "Should have indexed chunks"
        assert retriever.bm25 is not None, "BM25 index should be created"
        assert retriever.collection is not None, "ChromaDB collection should be created"

        print(f"✓ Indexed {len(test_docs)} documents")
        print(f"✓ Created {len(retriever.documents)} chunks")
        print("✓ BM25 index created")
        print("✓ Semantic index created")

        print("✅ Document indexing test PASSED")
        return True
    except Exception as e:
        print(f"❌ Document indexing test FAILED: {e}")
        return False
def test_retrieval_methods():
    """Check that each retrieval method of ``HybridRetriever`` returns a list.

    Indexes three sample documents, then runs the query "Python programming"
    through ``retrieve_bm25``, ``retrieve_semantic`` and ``retrieve_hybrid``
    (each with ``top_k=2``) and feeds the hybrid results to ``rerank`` with
    ``top_k=1``. Only the return type of each call is asserted.

    Returns:
        bool: True when every check passes. False when any call or assertion
        raises; the error is printed instead of propagated.
    """
    print("\n" + "=" * 60)
    print("TEST: Retrieval Methods")
    print("=" * 60)

    try:
        retriever = HybridRetriever(data_dir="data/test_methods")

        test_docs = {
            "doc1": "Python programming language for web development and machine learning applications",
            "doc2": "JavaScript is used for frontend development with React and Vue frameworks",
            "doc3": "SQL databases like PostgreSQL store structured data efficiently",
        }
        retriever.index_documents(test_docs, chunk_size=15, overlap=3)

        query = "Python programming"

        bm25_results = retriever.retrieve_bm25(query, top_k=2)
        assert isinstance(bm25_results, list), "BM25 should return a list"
        print(f"✓ BM25 retrieval: {len(bm25_results)} results")

        semantic_results = retriever.retrieve_semantic(query, top_k=2)
        assert isinstance(semantic_results, list), "Semantic should return a list"
        print(f"✓ Semantic retrieval: {len(semantic_results)} results")

        hybrid_results = retriever.retrieve_hybrid(query, top_k=2)
        assert isinstance(hybrid_results, list), "Hybrid should return a list"
        print(f"✓ Hybrid retrieval: {len(hybrid_results)} results")

        # Reranking consumes the hybrid candidates, so it must run after them.
        reranked = retriever.rerank(query, hybrid_results, top_k=1)
        assert isinstance(reranked, list), "Reranking should return a list"
        print(f"✓ Reranking: {len(reranked)} results")

        print("✅ Retrieval methods test PASSED")
        return True
    except Exception as e:
        print(f"❌ Retrieval methods test FAILED: {e}")
        return False
def test_rag_system():
    """Run one question through ``RAGSystem`` end to end and check the response.

    Builds an ``OpenAI`` client from the ``OPENAI_API_KEY`` environment
    variable, loads three sample documents into a ``RAGSystem`` and issues a
    single ``query`` call with ``method="hybrid"`` and ``top_k=3``. The
    response must contain the keys ``answer``, ``context`` and ``method``,
    and ``context`` must be non-empty.

    Returns:
        bool: True when every check passes. False when any call or assertion
        raises; the error is printed instead of propagated.
    """
    print("\n" + "=" * 60)
    print("TEST: RAG System End-to-End")
    print("=" * 60)

    try:
        client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        rag_system = RAGSystem(client, data_dir="data/test_rag")

        test_docs = {
            "summary": "I am an experienced AI engineer with 5 years of Python development",
            "projects": "Built RAG systems, multi-agent frameworks, and production ML pipelines",
            "stack": "Expert in Python, FastAPI, LangChain, ChromaDB, and OpenAI APIs",
        }
        rag_system.load_knowledge_base(test_docs, chunk_size=20, overlap=5)

        system_prompt = "Answer questions about professional background."
        response = rag_system.query(
            "What programming languages do you know?",
            system_prompt,
            method="hybrid",
            top_k=3,
        )

        assert "answer" in response, "Response should contain answer"
        assert "context" in response, "Response should contain context"
        assert "method" in response, "Response should contain method"
        assert len(response["context"]) > 0, "Should retrieve some context"

        print(f"✓ Retrieved {len(response['context'])} context documents")
        print(f"✓ Generated answer: {len(response['answer'])} characters")
        print(f"✓ Method used: {response['method']}")

        print("✅ RAG system test PASSED")
        return True
    except Exception as e:
        print(f"❌ RAG system test FAILED: {e}")
        return False
def run_all_tests(tests=None):
    """Run test callables in order and print a pass/fail summary.

    Args:
        tests: Optional iterable of zero-argument callables, each returning a
            truthy value on success. When omitted, the six RAG tests defined
            in this module are run in their original order.

    Returns:
        bool: True when every test reported success (also True for an empty
        suite), False otherwise.
    """
    print("\n" + "=" * 70)
    print("RUNNING RAG SYSTEM TESTS")
    print("=" * 70)

    if tests is None:
        # Resolved at call time so the default suite is only looked up when
        # the caller did not supply one.
        tests = [
            test_query_expansion,
            test_retriever_initialization,
            test_chunking,
            test_document_indexing,
            test_retrieval_methods,
            test_rag_system,
        ]

    # Coerce to bool so a test returning None (or any non-bool) is counted as
    # a plain failure/success instead of breaking the sum() below.
    results = [bool(test()) for test in tests]

    print("\n" + "=" * 70)
    print(f"RESULTS: {sum(results)}/{len(results)} tests passed")
    print("=" * 70)

    return all(results)
if __name__ == "__main__":
    # Exit status mirrors the suite outcome: 0 when all tests pass, 1 otherwise.
    exit_code = 0 if run_all_tests() else 1
    sys.exit(exit_code)