#!/usr/bin/env python3
"""
Test retrieval quality independently.
Useful for debugging and tuning retrieval parameters.
"""
import logging
import sys
from pathlib import Path

# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))

from src import create_retriever
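# The object returned by create_retriever() is assumed to expose
# get_collection_stats() and retrieve(query, top_k); both names are taken
# from how the retriever is called below, not from a documented API.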
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
TEST_QUERIES = [
    "How do I create a router in FastAPI?",
    "What are dependencies?",
    "How do I handle errors?",
    "Show me authentication examples",
    "How do I validate request bodies?",
]
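# Each retrieved result is expected to be a dict shaped roughly like the
# (hypothetical) example below, inferred from the fields accessed in
# test_retrieval(); the actual values depend on the indexed corpus:
#
#   {
#       "content": "APIRouter lets you split path operations across modules...",
#       "score": 0.82,
#       "metadata": {"title": "FastAPI Docs", "section": "Bigger Applications"},
#   }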
def test_retrieval():
    """Test retrieval with various queries."""
    logger.info("=" * 60)
    logger.info("Retrieval Quality Test")
    logger.info("=" * 60)

    # Initialize retriever
    retriever = create_retriever()

    stats = retriever.get_collection_stats()
    logger.info("\nVector Database Stats:")
    logger.info(f"  Total chunks: {stats['total_chunks']}")
    logger.info(f"  Collection: {stats['collection_name']}")
    logger.info(f"  Embedding dim: {stats['embedding_dimension']}")

    # Test each query
    for i, query in enumerate(TEST_QUERIES, 1):
        logger.info("\n" + "=" * 60)
        logger.info(f"Test {i}/{len(TEST_QUERIES)}")
        logger.info("=" * 60)
        logger.info(f"Query: {query}")

        # Retrieve
        results = retriever.retrieve(query, top_k=3)

        logger.info(f"\nFound {len(results)} results:")
        for j, result in enumerate(results, 1):
            logger.info(f"\n--- Result {j} ---")
            logger.info(f"Score: {result['score']:.4f}")
            logger.info(f"Source: {result['metadata'].get('title', 'Unknown')}")
            logger.info(f"Section: {result['metadata'].get('section', 'Unknown')}")
            logger.info("Content preview:")
            logger.info(f"{result['content'][:200]}...")

        # Quality check
        avg_score = sum(r['score'] for r in results) / len(results) if results else 0
        logger.info(f"\nAverage relevance score: {avg_score:.4f}")
        if avg_score >= 0.75:
            logger.info("✓ High quality results")
        elif avg_score >= 0.6:
            logger.info("⚠ Medium quality results")
        else:
            logger.info("✗ Low quality results - consider tuning")

    logger.info("\n" + "=" * 60)
    logger.info("Retrieval test complete")
    logger.info("=" * 60)
if __name__ == "__main__":
    test_retrieval()