# Author: Apurva Umredkar
# Commit d8f06d4: "added backend functionality"
# File-viewer residue kept for reference: raw | history | blame | 7.96 kB
import os
import pickle
import json
import numpy as np
from typing import List, Dict, Any, Optional, Tuple
import faiss
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, CrossEncoder
class VectorStore:
    """Chunk store backed by a FAISS index, with dense, keyword and reranked search.

    Attributes:
        embedding_dir: Directory holding ``embeddings.pkl`` (input) and
            ``faiss_index.pkl`` (cached index plus metadata).
        index: The FAISS index, or ``None`` when none could be loaded or built.
        chunk_ids: Chunk IDs in index order; position ``i`` in this list is
            the chunk for vector ``i`` in ``index``.
        chunks: Mapping of chunk ID to its chunk dict. The search methods read
            the ``'content'`` key, and optionally ``'title'`` and
            ``'categories'``.
        model: ``SentenceTransformer`` used to embed queries.
        reranker: ``CrossEncoder`` used to score ``(query, content)`` pairs.
    """

    def __init__(self,
                 embedding_dir: str = "data/embeddings",
                 model_name: str = "BAAI/bge-small-en-v1.5",
                 reranker_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        """Load both models, then load or build the index.

        Args:
            embedding_dir: Directory containing the pickled embeddings/index.
            model_name: Name passed to ``SentenceTransformer``.
            reranker_name: Name passed to ``CrossEncoder``.
        """
        self.embedding_dir = embedding_dir
        self.index = None
        self.chunk_ids = []
        self.chunks = {}

        # Load embedding model
        self.model = SentenceTransformer(model_name)

        # Load reranker model
        self.reranker = CrossEncoder(reranker_name)

        # Load or create index
        self.load_or_create_index()

    def load_or_create_index(self) -> None:
        """Load the cached index from disk, or build it from saved embeddings.

        Looks for ``faiss_index.pkl`` in ``embedding_dir`` first. If it is
        missing, falls back to :meth:`create_index` when ``embeddings.pkl``
        exists; otherwise prints a hint and leaves ``index`` as ``None``.
        """
        index_path = os.path.join(self.embedding_dir, 'faiss_index.pkl')

        if os.path.exists(index_path):
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file. Only load index files that this application produced.
            with open(index_path, 'rb') as f:
                data = pickle.load(f)

            self.index = data['index']
            self.chunk_ids = data['chunk_ids']
            self.chunks = data['chunks']
            print(f"Loaded existing index with {len(self.chunk_ids)} chunks")
            return

        embeddings_path = os.path.join(self.embedding_dir, 'embeddings.pkl')
        if os.path.exists(embeddings_path):
            self.create_index()
        else:
            print("No embeddings found. Please run the chunker first.")

    def create_index(self) -> None:
        """Build a ``faiss.IndexFlatL2`` from ``embeddings.pkl`` and cache it.

        The pickle is expected to map each chunk ID to a dict with an
        ``'embedding'`` vector and a ``'chunk'`` dict. On success the index
        and metadata are stored on ``self`` and written to
        ``faiss_index.pkl``. An empty embedding map is reported and leaves the
        store unchanged instead of failing on the missing vector dimension.

        Raises:
            FileNotFoundError: If ``embeddings.pkl`` does not exist.
        """
        embeddings_path = os.path.join(self.embedding_dir, 'embeddings.pkl')

        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only load embedding files produced by a trusted chunker run.
        with open(embeddings_path, 'rb') as f:
            embedding_map = pickle.load(f)

        if not embedding_map:
            # With no vectors there is no dimension to size the index with.
            print("Embeddings file is empty. Please run the chunker first.")
            return

        # Extract embeddings and chunk IDs; list order defines index order.
        chunk_ids = list(embedding_map)
        embeddings = np.array(
            [embedding_map[chunk_id]['embedding'] for chunk_id in chunk_ids]
        ).astype(np.float32)
        chunks = {chunk_id: embedding_map[chunk_id]['chunk'] for chunk_id in chunk_ids}

        # Create FAISS index
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings)

        # Keep index and metadata in memory
        self.index = index
        self.chunk_ids = chunk_ids
        self.chunks = chunks

        # Save to disk
        with open(os.path.join(self.embedding_dir, 'faiss_index.pkl'), 'wb') as f:
            pickle.dump({
                'index': index,
                'chunk_ids': chunk_ids,
                'chunks': chunks
            }, f)

        print(f"Created index with {len(chunk_ids)} chunks")

    @staticmethod
    def _matches_categories(chunk: Dict[str, Any],
                            filter_categories: Optional[List[str]]) -> bool:
        """Return True if no filter is given or the chunk has any listed category."""
        if not filter_categories:
            return True
        chunk_categories = chunk.get('categories', [])
        return any(cat in chunk_categories for cat in filter_categories)

    def _rerank(self, query: str, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Add a ``'rerank_score'`` to each result and return them best-first."""
        if not results:
            return results

        pairs = [(query, result['chunk']['content']) for result in results]
        rerank_scores = self.reranker.predict(pairs)
        for result, score in zip(results, rerank_scores):
            result['rerank_score'] = float(score)

        return sorted(results, key=lambda x: x['rerank_score'], reverse=True)

    def search(self,
               query: str,
               k: int = 5,
               filter_categories: Optional[List[str]] = None,
               rerank: bool = True) -> List[Dict[str, Any]]:
        """Search for relevant chunks with the dense index.

        Args:
            query: Query text to embed and search for.
            k: Maximum number of results to return.
            filter_categories: If given, keep only chunks whose
                ``'categories'`` contain at least one of these values.
            rerank: If True, score the candidates with the cross-encoder and
                order them by ``'rerank_score'`` (highest first).

        Returns:
            Up to ``k`` dicts with ``'chunk_id'``, ``'score'`` (the distance
            returned by the index; lower is closer for the L2 index built by
            :meth:`create_index`), ``'chunk'`` and, when reranked,
            ``'rerank_score'``. Empty when there is no index, no chunks, or
            ``k <= 0``.

        Note:
            Only ``2 * k`` candidates are fetched before the category filter
            is applied, so a selective filter can yield fewer than ``k`` hits.
        """
        if self.index is None:
            print("No index available. Please create an index first.")
            return []
        if k <= 0 or not self.chunk_ids:
            return []

        # Create query embedding
        query_embedding = self.model.encode([query])[0]

        # Over-fetch so filtering and reranking have some slack.
        fetch_k = min(k * 2, len(self.chunk_ids))
        D, I = self.index.search(np.array([query_embedding]).astype(np.float32), fetch_k)

        results = []
        for distance, idx in zip(D[0], I[0]):
            # FAISS pads missing neighbours with label -1; without this guard
            # chunk_ids[-1] would silently return the last chunk.
            if idx < 0:
                continue

            chunk_id = self.chunk_ids[idx]
            chunk = self.chunks[chunk_id]

            # Apply category filter if specified
            if not self._matches_categories(chunk, filter_categories):
                continue

            results.append({
                'chunk_id': chunk_id,
                'score': float(distance),
                'chunk': chunk
            })

        # Rerank results if requested
        if rerank:
            results = self._rerank(query, results)

        # Limit to k results
        return results[:k]

    def hybrid_search(self,
                      query: str,
                      k: int = 5,
                      filter_categories: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Combine dense vector search with simple keyword matching, then rerank.

        Keyword matching counts substring occurrences of each lower-cased
        query word in ``title + " " + content`` (a rough stand-in for BM25).
        Up to ``k`` dense hits and up to ``k`` keyword hits are merged without
        duplicates, the whole union is reranked by the cross-encoder, and only
        then cut to ``k``. Truncating before the rerank would drop every
        keyword hit whenever the dense search already returns ``k`` results.

        Args:
            query: Query text.
            k: Maximum number of results to return.
            filter_categories: If given, keep only chunks whose
                ``'categories'`` contain at least one of these values.

        Returns:
            Up to ``k`` result dicts ordered by ``'rerank_score'`` (highest
            first). For dense hits ``'score'`` is the index distance; for
            keyword-only hits it is the keyword match count, so the two are
            not comparable with each other.
        """
        if k <= 0:
            return []

        # Get vector search results
        vector_results = self.search(query, k=k, filter_categories=filter_categories, rerank=False)

        # Simple keyword matching (simulating BM25)
        keywords = query.lower().split()

        keyword_scores = {}
        for chunk_id, chunk in self.chunks.items():
            # Filter first so excluded chunks are never scored.
            if not self._matches_categories(chunk, filter_categories):
                continue

            # A missing title is treated as empty, matching search(), which
            # does not require a title either.
            title = chunk.get('title') or ''
            content = (title + " " + chunk['content']).lower()

            # Count keyword matches
            score = sum(content.count(keyword) for keyword in keywords)
            if score > 0:
                keyword_scores[chunk_id] = score

        # Get top keyword matches
        top_keyword_items = sorted(
            keyword_scores.items(), key=lambda item: item[1], reverse=True
        )[:k]
        keyword_results = [
            {'chunk_id': chunk_id, 'score': score, 'chunk': self.chunks[chunk_id]}
            for chunk_id, score in top_keyword_items
        ]

        # Combine results (remove duplicates), vector results first
        seen_ids = set()
        combined_results = []
        for result in vector_results + keyword_results:
            if result['chunk_id'] in seen_ids:
                continue
            combined_results.append(result)
            seen_ids.add(result['chunk_id'])

        # Rerank every candidate, then limit to k results
        combined_results = self._rerank(query, combined_results)
        return combined_results[:k]
# Demo entry point: run one hybrid query and print the three best hits.
if __name__ == "__main__":
    vector_store = VectorStore()
    results = vector_store.hybrid_search("How do I apply for OPT?")
    print(f"Found {len(results)} results")

    # zip with range(3) stops after three hits, or sooner if fewer came back.
    for i, result in zip(range(3), results):
        print(f"Result {i+1}: {result['chunk']['title']}")
        # Show the cross-encoder score when present, else the raw score.
        print(f"Score: {result.get('rerank_score', result['score'])}")
        print(f"Content: {result['chunk']['content'][:100]}...")
        print()