File size: 2,574 Bytes
14f13a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#!/usr/bin/env python3
"""
Test retrieval quality independently.

Useful for debugging and tuning retrieval parameters.
"""
import logging
import sys
from pathlib import Path

# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))

from src import create_retriever

# Root-logger setup: INFO level with timestamped, level-tagged lines so the
# per-query output below reads as a chronological report.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# Module-scoped logger, named after this module per logging convention.
logger = logging.getLogger(__name__)


# Fixed set of representative queries used to probe retrieval quality.
# Edit this list to exercise other topics; each entry is run independently.
TEST_QUERIES = [
    "How do I create a router in FastAPI?",
    "What are dependencies?",
    "How do I handle errors?",
    "Show me authentication examples",
    "How do I validate request bodies?",
]


def test_retrieval(top_k: int = 3, high_threshold: float = 0.75,
                   medium_threshold: float = 0.6) -> None:
    """Run every query in TEST_QUERIES against the retriever and log quality.

    For each query, retrieves the top chunks, logs a preview of each hit,
    and grades the average relevance score against the given thresholds.

    Args:
        top_k: Number of chunks to retrieve per query (default 3, matching
            the original hard-coded value).
        high_threshold: Minimum average score graded "high quality".
        medium_threshold: Minimum average score graded "medium quality".
    """
    separator = "=" * 60

    logger.info(separator)
    logger.info("Retrieval Quality Test")
    logger.info(separator)

    # Initialize the retriever once and report what we are searching over.
    retriever = create_retriever()
    stats = retriever.get_collection_stats()

    logger.info("\nVector Database Stats:")
    # .get() keeps the report running even if a stats key is absent,
    # instead of aborting the whole sweep with a KeyError.
    logger.info(f"  Total chunks: {stats.get('total_chunks', 'Unknown')}")
    logger.info(f"  Collection: {stats.get('collection_name', 'Unknown')}")
    logger.info(f"  Embedding dim: {stats.get('embedding_dimension', 'Unknown')}")

    # Run each query independently so one weak query doesn't mask another.
    for i, query in enumerate(TEST_QUERIES, 1):
        logger.info("\n" + separator)
        logger.info(f"Test {i}/{len(TEST_QUERIES)}")
        logger.info(separator)
        logger.info(f"Query: {query}")

        results = retriever.retrieve(query, top_k=top_k)

        logger.info(f"\nFound {len(results)} results:")

        for j, result in enumerate(results, 1):
            logger.info(f"\n--- Result {j} ---")
            logger.info(f"Score: {result['score']:.4f}")
            logger.info(f"Source: {result['metadata'].get('title', 'Unknown')}")
            logger.info(f"Section: {result['metadata'].get('section', 'Unknown')}")
            logger.info("Content preview:")
            logger.info(f"{result['content'][:200]}...")

        # Average over retrieved scores; falls back to 0 on an empty result
        # set so we never divide by zero.
        avg_score = sum(r['score'] for r in results) / len(results) if results else 0
        logger.info(f"\nAverage relevance score: {avg_score:.4f}")

        if avg_score >= high_threshold:
            logger.info("✓ High quality results")
        elif avg_score >= medium_threshold:
            logger.info("⚠ Medium quality results")
        else:
            logger.info("✗ Low quality results - consider tuning")

    logger.info("\n" + separator)
    logger.info("Retrieval test complete")
    logger.info(separator)


if __name__ == "__main__":
    # Script entry point: run the full retrieval-quality sweep.
    test_retrieval()