Spaces:
Sleeping
Sleeping
| """test_rag.py module.""" | |
| import pytest | |
| from src.rag.retriever import FinancialDataRetriever | |
| from src.rag.generator import RAGGenerator | |
| import yaml | |
@pytest.fixture
def rag_config():
    """Load the server config and override its ``rag`` section for tests.

    Reads ``config/server_config.yaml`` (a relative path, so it is resolved
    against the current working directory) and replaces the ``rag`` key with
    a fixed FAISS configuration, so the tests do not depend on whatever the
    real config file specifies for RAG.

    Returns:
        The parsed configuration dict with the test ``rag`` section applied.
    """
    with open('config/server_config.yaml', 'r', encoding='utf-8') as f:
        # safe_load returns None for an empty document; fall back to an empty
        # dict so the 'rag' override below can still be applied.
        config = yaml.safe_load(f) or {}
    config['rag'] = {
        'retriever': 'faiss',
        'max_documents': 5,
        'similarity_threshold': 0.7,
    }
    return config


@pytest.fixture
def retriever(rag_config):
    """Build a FinancialDataRetriever from the test ``rag_config`` fixture."""
    return FinancialDataRetriever(rag_config)


@pytest.fixture
def generator(rag_config):
    """Build a RAGGenerator from the test ``rag_config`` fixture."""
    return RAGGenerator(rag_config)
def test_retriever_initialization(retriever, rag_config):
    """The retriever should pick up its type and document cap from the config."""
    rag_section = rag_config['rag']
    assert retriever.retriever_type == rag_section['retriever']
    assert retriever.max_documents == rag_section['max_documents']
def test_document_indexing(retriever):
    """Indexing N documents should leave exactly N entries in the index."""
    titles = ('Financial report 2023', 'Market analysis Q4', 'Investment strategy')
    # ids are assigned 1, 2, 3 in the same order as the titles above.
    docs = [{'text': title, 'id': doc_id} for doc_id, title in enumerate(titles, start=1)]
    retriever.index_documents(docs)
    assert retriever.index.ntotal == len(docs)
def test_document_retrieval(retriever):
    """A query matching an indexed document should return scored results."""
    corpus = (
        (1, 'Financial report 2023'),
        (2, 'Market analysis Q4'),
        (3, 'Investment strategy'),
    )
    retriever.index_documents([{'text': text, 'id': doc_id} for doc_id, text in corpus])

    results = retriever.retrieve("financial report")

    assert len(results) > 0
    # Every result must carry both the matched document and its score.
    for required_key in ('document', 'score'):
        assert all(required_key in hit for hit in results)
def test_generator_initialization(generator):
    """The generator must expose both a ``model`` and a ``tokenizer`` attribute."""
    for attribute in ('model', 'tokenizer'):
        assert hasattr(generator, attribute)
def test_text_generation(generator):
    """generate() should return a non-empty string for a simple query."""
    source_doc = {'text': 'Financial market analysis shows positive trends'}
    output = generator.generate(
        query="Summarize market trends",
        retrieved_docs=[{'document': source_doc, 'score': 0.9}],
    )
    assert isinstance(output, str) and len(output) > 0
def test_context_preparation(generator):
    """prepare_context() should return a string containing every document's text."""
    texts = ('Doc 1 content', 'Doc 2 content')
    scores = (0.9, 0.8)
    docs = [{'document': {'text': text}, 'score': score} for text, score in zip(texts, scores)]

    context = generator.prepare_context(docs)

    assert isinstance(context, str)
    for text in texts:
        assert text in context