# ai-engineering-project / scripts / test_e2e_pipeline.py
# Deployed via GitHub Action — "Clean deployment without binary files" (commit f884e6e)
#!/usr/bin/env python3
"""
End-to-End Pipeline Test for HuggingFace CI/CD
This script tests the complete RAG pipeline with citation validation.
"""
# Add src to path
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
def test_citation_fix():
    """Verify the citation fix: context labels, citation extraction, prompt content.

    Returns True when every check passes, False on any failure (including
    a failed project import, which is caught and reported).
    """
    print("🧪 Testing Citation Fix...")
    try:
        from llm.prompt_templates import PromptTemplates

        # Mock retrieval hits used to drive the context formatter.
        sample_hits = [
            {
                "content": "Remote work is allowed up to 3 days per week.",
                "metadata": {"source_file": "remote_work_policy.md"},
                "similarity_score": 0.89,
            },
            {
                "content": "All employees must follow the code of conduct.",
                "metadata": {"source_file": "employee_handbook.md"},
                "similarity_score": 0.75,
            },
        ]

        context_text = PromptTemplates.format_context(sample_hits)

        # 1) New "SOURCE FILE:" labels must appear; the old "Document N:" must not.
        for expected_label in (
            "SOURCE FILE: remote_work_policy.md",
            "SOURCE FILE: employee_handbook.md",
        ):
            assert expected_label in context_text
        assert "Document 1:" not in context_text  # Old format should be gone
        print("✅ Context formatting fix verified")

        # 2) Extraction pulls the exact filename out of a cited response.
        sample_response = "Based on the policy [Source: remote_work_policy.md], employees can work remotely."
        found_citations = PromptTemplates.extract_citations(sample_response)
        assert len(found_citations) == 1
        assert "remote_work_policy.md" in found_citations
        print("✅ Citation extraction working correctly")

        # 3) The system prompt must carry the citation-fix instructions.
        qa_template = PromptTemplates.get_policy_qa_template()
        for marker in ("CRITICAL", "exact filename", "document_1.md"):
            assert marker in qa_template.system_prompt
        print("✅ System prompt contains citation fix")
        return True
    except Exception as e:
        print(f"❌ Citation fix test failed: {e}")
        return False
def test_service_imports():
    """Smoke-test that each service module is importable.

    Returns True when both imports succeed, False otherwise.
    """
    print("\n🔧 Testing Service Imports...")
    try:
        # HF embedding service must import cleanly.
        from embedding.hf_embedding_service import HFEmbeddingService

        _ = HFEmbeddingService  # touch the name so linters see it used
        print("✅ HF Embedding Service imported")

        # Prompt templates must import cleanly.
        from llm.prompt_templates import PromptTemplates

        _ = PromptTemplates
        print("✅ Prompt Templates imported")
        return True
    except Exception as e:
        print(f"❌ Service import test failed: {e}")
        return False
def test_architecture_integration():
    """Exercise the hybrid architecture end to end: context → template → prompt.

    Returns True when the full prompt workflow produces the expected text,
    False on any failure.
    """
    print("\n🏗️ Testing Architecture Integration...")
    try:
        from llm.prompt_templates import PromptTemplates

        # A single mock retrieval hit drives the whole workflow.
        hit = {
            "content": "Test policy content for integration test",
            "metadata": {"source_file": "integration_test_policy.md"},
            "similarity_score": 0.95,
        }

        rendered_context = PromptTemplates.format_context([hit])
        qa_template = PromptTemplates.get_policy_qa_template()

        question = "What is the integration test policy?"
        prompt_text = qa_template.user_template.format(
            question=question, context=rendered_context
        )

        # The rendered prompt must embed both the question and the cited source,
        # and the template must ship a system prompt.
        assert "What is the integration test policy?" in prompt_text
        assert "SOURCE FILE: integration_test_policy.md" in prompt_text
        assert qa_template.system_prompt is not None
        print("✅ Complete prompt workflow functional")
        return True
    except Exception as e:
        print(f"❌ Architecture integration test failed: {e}")
        return False
def main():
    """Run every pipeline test, print a summary, and return a process exit code.

    Returns 0 when all tests pass, 1 otherwise.
    """
    print("🚀 End-to-End Pipeline Test")
    print("=" * 30)

    # Name/callable pairs, executed in order.
    suite = (
        ("Citation Fix", test_citation_fix),
        ("Service Imports", test_service_imports),
        ("Architecture Integration", test_architecture_integration),
    )

    passed = 0
    for name, runner in suite:
        print(f"\n🧪 Running: {name}")
        if runner():
            passed += 1
        else:
            print(f"❌ {name} failed")

    total = len(suite)
    print("\n" + "=" * 30)
    print(f"Pipeline Test Summary: {passed}/{total} passed")

    if passed != total:
        print("⚠️ Some pipeline tests failed.")
        return 1
    print("🎉 End-to-end pipeline test successful!")
    return 0
if __name__ == "__main__":
    # Propagate the test outcome as the process exit status (0 = success).
    sys.exit(main())