File size: 2,286 Bytes
e0ee929
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""
Top-k semantic retrieval against FAISS index.
"""
from typing import List, Dict, Callable
import numpy as np
import faiss


def retrieve(query: str, index, embed_fn: Callable, metadata_df, chunks_lookup: dict = None, k: int = 5) -> List[Dict]:
    """
    Return up to k nearest chunks for a query, with text and metadata.

    Args:
        query: Query string.
        index: FAISS index (any object exposing ``search(vectors, k)`` that
            returns a ``(scores, indices)`` pair of 2-D arrays).
        embed_fn: Function that takes a string and returns the query embedding,
            either 1-D or 2-D, as a NumPy array or any array-like. This
            function does not normalize it; callers are expected to pass a
            function that returns a normalized vector if the index needs one.
        metadata_df: DataFrame whose row positions correspond to index ids,
            with columns chunk_id, book, para_idx_start, para_idx_end,
            char_count (and optionally text).
        chunks_lookup: Optional dict mapping chunk_id to chunk dict with 'text' field.
        k: Maximum number of results to return.

    Returns:
        List of dicts ``{score, text, chunk_id, meta: {...}}`` in the order the
        index returned them. The list holds at most k entries: hits whose id is
        negative or beyond the metadata rows are dropped, and a non-positive k
        yields an empty list without querying the index.
    """
    # Nothing to retrieve; avoid handing a non-positive k to the index.
    if k <= 0:
        return []

    # Coerce whatever embed_fn returns (list, float64 array, strided view, ...)
    # into the contiguous float32 matrix that FAISS search expects.
    query_embedding = np.ascontiguousarray(embed_fn(query), dtype=np.float32)
    if query_embedding.ndim == 1:
        query_embedding = query_embedding.reshape(1, -1)

    # Search the index; only the first query row is used.
    scores, indices = index.search(query_embedding, k)

    row_count = len(metadata_df)
    results = []
    for score, idx in zip(scores[0], indices[0]):
        idx = int(idx)
        # Skip padding ids (-1) and ids with no metadata row.
        if idx < 0 or idx >= row_count:
            continue

        row = metadata_df.iloc[idx]
        chunk_id = row['chunk_id']

        # Text preference: chunks_lookup, then a 'text' column, then a placeholder.
        if chunks_lookup and chunk_id in chunks_lookup:
            text = chunks_lookup[chunk_id].get('text', '')
        elif 'text' in row:
            text = row['text']
        else:
            text = f"[Chunk {chunk_id} - text not available]"

        results.append({
            'score': float(score),
            'text': text,
            'chunk_id': chunk_id,
            'meta': {
                'book': row['book'],
                'para_idx_start': int(row['para_idx_start']),
                'para_idx_end': int(row['para_idx_end']),
                'char_count': int(row['char_count']),
            },
        })

    return results