DeveloperDocs_RAG / test_retrieval.py
Aishwarya30998's picture
Deploy DeveloperDocs-AI-Copilot-RAG to Hugging Face Space
14f13a5
#!/usr/bin/env python3
"""
Test retrieval quality independently.
Useful for debugging and tuning retrieval parameters.
"""
import logging
import sys
from pathlib import Path
# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from src import create_retriever
# Configure the root logger for this script: INFO and above, each record
# prefixed with its timestamp and level name.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger used by everything below.
logger = logging.getLogger(__name__)
# Sample questions sent to the retriever, one retrieval run per entry.
TEST_QUERIES = [
    "How do I create a router in FastAPI?",
    "What are dependencies?",
    "How do I handle errors?",
    "Show me authentication examples",
    "How do I validate request bodies?",
]
def test_retrieval(queries=None, top_k=3):
    """Run sample queries through the retriever and log the result quality.

    Builds a retriever with ``create_retriever()``, logs the collection
    statistics it reports, then for each query logs every returned result
    (score, title, section, content preview) followed by the average score
    and a coarse quality label.

    Args:
        queries: Iterable of query strings to run. When ``None`` (the
            default) the module-level ``TEST_QUERIES`` list is used.
        top_k: Number of results requested from the retriever per query.
            Defaults to 3.

    Returns:
        None. All output goes to the module logger.
    """
    banner = "=" * 60
    preview_limit = 200  # max characters of content shown per result

    # Materialise once so len() works for any iterable the caller passes.
    queries = list(TEST_QUERIES if queries is None else queries)

    logger.info(banner)
    logger.info("Retrieval Quality Test")
    logger.info(banner)

    # Initialize retriever
    retriever = create_retriever()
    stats = retriever.get_collection_stats()
    logger.info("\nVector Database Stats:")
    logger.info(f" Total chunks: {stats['total_chunks']}")
    logger.info(f" Collection: {stats['collection_name']}")
    logger.info(f" Embedding dim: {stats['embedding_dimension']}")

    # Test each query
    for i, query in enumerate(queries, 1):
        logger.info("\n" + banner)
        logger.info(f"Test {i}/{len(queries)}")
        logger.info(banner)
        logger.info(f"Query: {query}")

        # Retrieve
        results = retriever.retrieve(query, top_k=top_k)
        logger.info(f"\nFound {len(results)} results:")

        for j, result in enumerate(results, 1):
            metadata = result['metadata']
            content = result['content']

            # Only mark the preview with an ellipsis when text was actually
            # cut off; otherwise short chunks would look truncated.
            preview = content[:preview_limit]
            if len(content) > preview_limit:
                preview += "..."

            logger.info(f"\n--- Result {j} ---")
            logger.info(f"Score: {result['score']:.4f}")
            logger.info(f"Source: {metadata.get('title', 'Unknown')}")
            logger.info(f"Section: {metadata.get('section', 'Unknown')}")
            logger.info("Content preview:")
            logger.info(preview)

        # Quality check: an empty result set averages to 0 rather than
        # dividing by zero, so it falls into the lowest bucket.
        avg_score = sum(r['score'] for r in results) / len(results) if results else 0
        logger.info(f"\nAverage relevance score: {avg_score:.4f}")

        if avg_score >= 0.75:
            logger.info("✓ High quality results")
        elif avg_score >= 0.6:
            logger.info("⚠ Medium quality results")
        else:
            logger.info("✗ Low quality results - consider tuning")

    logger.info("\n" + banner)
    logger.info("Retrieval test complete")
    logger.info(banner)
# Script entry point: run the retrieval check only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    test_retrieval()