caps-chatbot-internal / retrieval.py
atwine's picture
Replace FAISS index with portable numpy embeddings in pkl
543fbbd
"""
retrieval.py
------------
Sanyu RAG — Retrieval Module
Loads pre-computed L2-normalised numpy embeddings from the .pkl file and
performs retrieval via a simple dot-product similarity search (pure numpy).
No FAISS dependency at runtime — avoids FAISS SWIG binary incompatibilities
between build environments (Windows vs Linux HF Space).
The .pkl file is located at: data/sanyu_knowledge_base.pkl (hardcoded as agreed).
"""
import pickle
import numpy as np
from sentence_transformers import SentenceTransformer
# Path to the serialised knowledge base — hardcoded as confirmed with Atwine.
DEFAULT_PKL_PATH = 'data/sanyu_knowledge_base.pkl'


def load_index(pkl_path: str = DEFAULT_PKL_PATH) -> tuple:
    """
    Load the chunk embeddings and chunk metadata from the .pkl file.

    Args:
        pkl_path: Path to the pickled knowledge base. Defaults to
            DEFAULT_PKL_PATH.

    Returns:
        (embeddings, chunks, model_name)
        - embeddings: np.ndarray, coerced to float32 here. Expected to be
          shape (n, d) and already L2-normalised by whatever built the
          file; this function does not re-normalise or check the shape.
        - chunks: the object stored under 'chunks' (expected: a list of
          chunk dicts with text + metadata), returned unchanged.
        - model_name: the value stored under 'embedding_model' (expected:
          the name of the embedding model used to build the index).

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If the payload lacks 'embeddings', 'chunks' or
            'embedding_model' (e.g. a .pkl produced by an older build).
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point pkl_path at knowledge bases produced by a trusted build step.
    with open(pkl_path, 'rb') as f:
        payload = pickle.load(f)

    # Fail with a message that names the file and every missing key, rather
    # than a bare KeyError on whichever key happens to be read first.
    required = ('embeddings', 'chunks', 'embedding_model')
    missing = [key for key in required if key not in payload]
    if missing:
        raise KeyError(
            f"{pkl_path} is missing required key(s): {', '.join(missing)}"
        )

    # Enforce the documented dtype; this is a no-op (no copy) when the stored
    # array is already a float32 ndarray.
    embeddings = np.asarray(payload['embeddings'], dtype='float32')
    return embeddings, payload['chunks'], payload['embedding_model']
def retrieve(query: str,
             embeddings: np.ndarray,
             chunks: list,
             model: "SentenceTransformer",
             top_k: int = 4) -> list:
    """
    Retrieve the top_k most relevant chunks for a given query.

    Scores every chunk with a numpy dot product between the encoded query
    and the pre-computed chunk embeddings. This equals cosine similarity
    provided the stored embeddings are L2-normalised (the query is encoded
    with normalize_embeddings=True). No FAISS required at runtime.

    Args:
        query: The user's input string.
        embeddings: Array of shape (n, d) holding the chunk embeddings,
            expected to be L2-normalised.
        chunks: The chunk dicts corresponding row-for-row to `embeddings`.
        model: A loaded SentenceTransformer model instance; only its
            `encode` method is used.
        top_k: Maximum number of results to return (default 4). Values
            <= 0 yield an empty list; values larger than n yield n results.

    Returns:
        List of shallow copies of the chunk dicts, each with an added
        'similarity_score' key (float), ordered from most to least
        relevant. The input chunks are not mutated.

    Raises:
        IndexError: If `chunks` has fewer entries than `embeddings` has
            rows and a selected row has no matching chunk.
    """
    # A negative top_k would otherwise slice as [:-k] and silently return
    # almost the whole index; treat any non-positive value as "no results"
    # and skip the (comparatively expensive) model call.
    if top_k <= 0:
        return []

    embeddings = np.asarray(embeddings)
    if embeddings.size == 0:
        # Empty index: nothing to score, and the matmul below would fail for
        # an empty 1-D array.
        return []

    query_embedding = np.asarray(
        model.encode([query], normalize_embeddings=True),
        dtype='float32',
    )  # shape (1, d)

    # Cosine similarity via dot product (both sides are L2-normalised).
    scores = (query_embedding @ embeddings.T).flatten()  # shape (n,)

    # Indices of the top_k highest scores, best first.
    top_indices = np.argsort(scores)[::-1][:top_k]

    results = []
    for idx in top_indices:
        # Copy so the caller's chunk dicts are not polluted with scores.
        chunk = chunks[int(idx)].copy()
        chunk['similarity_score'] = float(scores[idx])
        results.append(chunk)
    return results