#!/usr/bin/env python3
"""
End-to-End Pipeline Test for HuggingFace CI/CD

This script tests the complete RAG pipeline with citation validation.
"""
import os
import sys

# Add src to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))


def test_citation_fix():
    """Test that the citation fix is working properly."""
    print("🧪 Testing Citation Fix...")
    try:
        from llm.prompt_templates import PromptTemplates

        # Test 1: Context formatting
        mock_results = [
            {
                "content": "Remote work is allowed up to 3 days per week.",
                "metadata": {"source_file": "remote_work_policy.md"},
                "similarity_score": 0.89,
            },
            {
                "content": "All employees must follow the code of conduct.",
                "metadata": {"source_file": "employee_handbook.md"},
                "similarity_score": 0.75,
            },
        ]
        formatted_context = PromptTemplates.format_context(mock_results)

        # Verify the fix
        assert "SOURCE FILE: remote_work_policy.md" in formatted_context
        assert "SOURCE FILE: employee_handbook.md" in formatted_context
        assert "Document 1:" not in formatted_context  # Old format should be gone
        print("✅ Context formatting fix verified")

        # Test 2: Citation extraction
        test_response = "Based on the policy [Source: remote_work_policy.md], employees can work remotely."
        citations = PromptTemplates.extract_citations(test_response)
        assert len(citations) == 1
        assert "remote_work_policy.md" in citations
        print("✅ Citation extraction working correctly")

        # Test 3: System prompt contains fix
        template = PromptTemplates.get_policy_qa_template()
        assert "CRITICAL" in template.system_prompt
        assert "exact filename" in template.system_prompt
        assert "document_1.md" in template.system_prompt  # Warning should be present
        print("✅ System prompt contains citation fix")

        return True
    except Exception as e:
        print(f"❌ Citation fix test failed: {e}")
        return False


def test_service_imports():
    """Test that all services can be imported."""
    print("\n🔧 Testing Service Imports...")
    try:
        # Test HF embedding service
        from embedding.hf_embedding_service import HFEmbeddingService  # noqa: F401

        print("✅ HF Embedding Service imported")

        # Test prompt templates
        from llm.prompt_templates import PromptTemplates  # noqa: F401

        print("✅ Prompt Templates imported")

        return True
    except Exception as e:
        print(f"❌ Service import test failed: {e}")
        return False


def test_architecture_integration():
    """Test that the hybrid architecture components work together."""
    print("\n🏗️ Testing Architecture Integration...")
    try:
        from llm.prompt_templates import PromptTemplates

        # Test that we can create a complete prompt workflow
        mock_search_results = [
            {
                "content": "Test policy content for integration test",
                "metadata": {"source_file": "integration_test_policy.md"},
                "similarity_score": 0.95,
            }
        ]

        # Format context
        context = PromptTemplates.format_context(mock_search_results)

        # Get template
        template = PromptTemplates.get_policy_qa_template()

        # Create user prompt
        user_query = "What is the integration test policy?"
        user_prompt = template.user_template.format(question=user_query, context=context)

        # Verify complete prompt structure
        assert "What is the integration test policy?" in user_prompt
        assert "SOURCE FILE: integration_test_policy.md" in user_prompt
        assert template.system_prompt is not None
        print("✅ Complete prompt workflow functional")

        return True
    except Exception as e:
        print(f"❌ Architecture integration test failed: {e}")
        return False


def main():
    """Run the end-to-end pipeline test."""
    print("🚀 End-to-End Pipeline Test")
    print("=" * 30)

    tests = [
        ("Citation Fix", test_citation_fix),
        ("Service Imports", test_service_imports),
        ("Architecture Integration", test_architecture_integration),
    ]

    passed = 0
    total = len(tests)

    for test_name, test_func in tests:
        print(f"\n🧪 Running: {test_name}")
        if test_func():
            passed += 1
        else:
            print(f"❌ {test_name} failed")

    print("\n" + "=" * 30)
    print(f"Pipeline Test Summary: {passed}/{total} passed")

    if passed == total:
        print("🎉 End-to-end pipeline test successful!")
        return 0
    else:
        print("⚠️ Some pipeline tests failed.")
        return 1


if __name__ == "__main__":
    exit_code = main()
    sys.exit(exit_code)