#!/usr/bin/env python3
"""
End-to-End Pipeline Test for HuggingFace CI/CD

This script tests the complete RAG pipeline with citation validation.
"""
import os
import sys

# Add src to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
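
# The tests below exercise three helpers from this project's PromptTemplates
# (format_context, extract_citations, get_policy_qa_template); the expected
# behavior is inferred from the assertions in each test, not from external docs.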


def test_citation_fix():
    """Test that the citation fix is working properly."""
    print("🧪 Testing Citation Fix...")

    try:
        from llm.prompt_templates import PromptTemplates

        # Test 1: Context formatting
        mock_results = [
            {
                "content": "Remote work is allowed up to 3 days per week.",
                "metadata": {"source_file": "remote_work_policy.md"},
                "similarity_score": 0.89,
            },
            {
                "content": "All employees must follow the code of conduct.",
                "metadata": {"source_file": "employee_handbook.md"},
                "similarity_score": 0.75,
            },
        ]
        formatted_context = PromptTemplates.format_context(mock_results)

        # Verify the fix
        assert "SOURCE FILE: remote_work_policy.md" in formatted_context
        assert "SOURCE FILE: employee_handbook.md" in formatted_context
        assert "Document 1:" not in formatted_context  # Old format should be gone
        print("✅ Context formatting fix verified")

        # Test 2: Citation extraction
        test_response = "Based on the policy [Source: remote_work_policy.md], employees can work remotely."
        citations = PromptTemplates.extract_citations(test_response)
        assert len(citations) == 1
        assert "remote_work_policy.md" in citations
        print("✅ Citation extraction working correctly")

        # Test 3: System prompt contains fix
        template = PromptTemplates.get_policy_qa_template()
        assert "CRITICAL" in template.system_prompt
        assert "exact filename" in template.system_prompt
        assert "document_1.md" in template.system_prompt  # Warning should be present
        print("✅ System prompt contains citation fix")

        return True
    except Exception as e:
        print(f"❌ Citation fix test failed: {e}")
        return False


def test_service_imports():
    """Test that all services can be imported."""
    print("\n🔧 Testing Service Imports...")

    try:
        # Test HF embedding service
        from embedding.hf_embedding_service import HFEmbeddingService  # noqa: F401

        print("✅ HF Embedding Service imported")

        # Test prompt templates
        from llm.prompt_templates import PromptTemplates  # noqa: F401

        print("✅ Prompt Templates imported")

        return True
    except Exception as e:
        print(f"❌ Service import test failed: {e}")
        return False


def test_architecture_integration():
    """Test that the hybrid architecture components work together."""
    print("\n🏗️ Testing Architecture Integration...")

    try:
        from llm.prompt_templates import PromptTemplates

        # Test that we can create a complete prompt workflow
        mock_search_results = [
            {
                "content": "Test policy content for integration test",
                "metadata": {"source_file": "integration_test_policy.md"},
                "similarity_score": 0.95,
            }
        ]

        # Format context
        context = PromptTemplates.format_context(mock_search_results)

        # Get template
        template = PromptTemplates.get_policy_qa_template()

        # Create user prompt
        user_query = "What is the integration test policy?"
        user_prompt = template.user_template.format(question=user_query, context=context)

        # Verify complete prompt structure
        assert "What is the integration test policy?" in user_prompt
        assert "SOURCE FILE: integration_test_policy.md" in user_prompt
        assert template.system_prompt is not None
        print("✅ Complete prompt workflow functional")

        return True
    except Exception as e:
        print(f"❌ Architecture integration test failed: {e}")
        return False


def main():
    """Run the end-to-end pipeline test."""
    print("🚀 End-to-End Pipeline Test")
    print("=" * 30)

    tests = [
        ("Citation Fix", test_citation_fix),
        ("Service Imports", test_service_imports),
        ("Architecture Integration", test_architecture_integration),
    ]

    passed = 0
    total = len(tests)

    for test_name, test_func in tests:
        print(f"\n🧪 Running: {test_name}")
        if test_func():
            passed += 1
        else:
            print(f"❌ {test_name} failed")

    print("\n" + "=" * 30)
    print(f"Pipeline Test Summary: {passed}/{total} passed")

    if passed == total:
        print("🎉 End-to-end pipeline test successful!")
        return 0
    else:
        print("⚠️ Some pipeline tests failed.")
        return 1
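

# The exit code below (0 = all tests passed, 1 = at least one failed) is what
# a CI step would check to mark the run green or red.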
if __name__ == "__main__":
    exit_code = main()
    sys.exit(exit_code)