File size: 1,845 Bytes
3998131 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
"""
Test script for RAG Chain
Runs end-to-end tests with sample queries
"""
import os
import logging
from dotenv import load_dotenv
# Load env vars before importing modules that might use them
load_dotenv()
from module_a.rag_chain import LegalRAGChain
from module_a.config import LOG_LEVEL, LOG_FORMAT
# Set up root logging from the project's configured level and format, then
# obtain the logger used by this script.
logging.basicConfig(format=LOG_FORMAT, level=LOG_LEVEL)
logger = logging.getLogger(name=__name__)
def main():
    """Run the RAG chain end-to-end against a fixed set of sample queries.

    Prints each query's explanation and cited sources to stdout. Each query
    runs inside its own ``try`` so that one failing query does not prevent
    the remaining queries from being exercised.

    Returns:
        int: Exit status -- 0 when the chain initialized and every query
        completed without raising; 1 when ``MISTRAL_API_KEY`` is missing or
        empty, when chain initialization raised, or when any query raised.
    """
    print("=" * 80)
    print("Testing Nepal Justice Weaver - RAG Chain")
    print("=" * 80)

    # Check the API key before building the chain so a misconfigured
    # environment fails fast with a readable message.
    api_key = os.getenv("MISTRAL_API_KEY")
    if not api_key:
        print("\n✗ Error: MISTRAL_API_KEY not found!")
        print("Please set it in .env file or environment variable.")
        print("Example: export MISTRAL_API_KEY='your_key_here'")
        return 1

    # Initialization is handled separately from the query loop: if it fails
    # there is nothing left to test, so report it and stop.
    print("\nInitializing RAG Chain...")
    try:
        rag = LegalRAGChain()
    except Exception as e:
        logger.exception("RAG chain initialization failed")
        print(f"\n✗ Test failed: {e}")
        return 1

    test_queries = [
        "I am a single mother, how to get citizenship for my child?",
        "Can daughters inherit property like sons?",
        "What are the fundamental rights regarding equality?",
    ]

    failed = []
    for query in test_queries:
        print(f"\n\n{'=' * 80}")
        print(f"QUERY: {query}")
        print(f"{'=' * 80}")
        # Broad catch is deliberate: this is the harness boundary. The error
        # is logged with its traceback and counted, not swallowed, and the
        # loop moves on to the next query.
        try:
            result = rag.run(query)
            print(f"\nEXPLANATION:\n{result['explanation']}")
            print("\nSOURCES:")
            for source in result['sources']:
                print(f"- {source['section']} ({source['file']})")
        except Exception as e:
            logger.exception("Test failed for query: %s", query)
            print(f"\n✗ Test failed: {e}")
            failed.append(query)

    print(f"\n\n{'=' * 80}")
    if failed:
        print(f"✗ {len(failed)}/{len(test_queries)} queries failed")
        return 1
    print(f"✓ All {len(test_queries)} queries completed")
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status so CI or a
    # shell caller can detect a failed run (a None return still exits 0).
    raise SystemExit(main())
|