import os
import pickle
from typing import Any, Dict, List, Optional

import faiss
import numpy as np
from sentence_transformers import CrossEncoder, SentenceTransformer

class VectorStore:
    """FAISS-backed vector store over document chunks, with cross-encoder reranking."""

    def __init__(self,
                 embedding_dir: str = "data/embeddings",
                 model_name: str = "BAAI/bge-small-en-v1.5",
                 reranker_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        self.embedding_dir = embedding_dir
        self.index = None
        self.chunk_ids = []
        self.chunks = {}
        self.model = SentenceTransformer(model_name)
        self.reranker = CrossEncoder(reranker_name)
        self.load_or_create_index()

    def load_or_create_index(self) -> None:
        """Load a previously pickled index if present; otherwise build one from embeddings."""
        index_path = os.path.join(self.embedding_dir, 'faiss_index.pkl')
        if os.path.exists(index_path):
            with open(index_path, 'rb') as f:
                data = pickle.load(f)
            self.index = data['index']
            self.chunk_ids = data['chunk_ids']
            self.chunks = data['chunks']
            print(f"Loaded existing index with {len(self.chunk_ids)} chunks")
        else:
            embeddings_path = os.path.join(self.embedding_dir, 'embeddings.pkl')
            if os.path.exists(embeddings_path):
                self.create_index()
            else:
                print("No embeddings found. Please run the chunker first.")

    def create_index(self) -> None:
        """Create a FAISS index from precomputed embeddings and persist it to disk."""
        embeddings_path = os.path.join(self.embedding_dir, 'embeddings.pkl')
        with open(embeddings_path, 'rb') as f:
            embedding_map = pickle.load(f)

        chunk_ids = list(embedding_map.keys())
        embeddings = np.array([embedding_map[chunk_id]['embedding'] for chunk_id in chunk_ids])
        chunks = {chunk_id: embedding_map[chunk_id]['chunk'] for chunk_id in chunk_ids}

        # Exact (brute-force) L2 index: no training step, fine at this scale.
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings.astype(np.float32))

        self.index = index
        self.chunk_ids = chunk_ids
        self.chunks = chunks

        with open(os.path.join(self.embedding_dir, 'faiss_index.pkl'), 'wb') as f:
            pickle.dump({
                'index': index,
                'chunk_ids': chunk_ids,
                'chunks': chunks
            }, f)
        print(f"Created index with {len(chunk_ids)} chunks")

    def search(self,
               query: str,
               k: int = 5,
               filter_categories: Optional[List[str]] = None,
               rerank: bool = True) -> List[Dict[str, Any]]:
        """Embed the query, retrieve nearest chunks, then optionally filter and rerank."""
        if self.index is None:
            print("No index available. Please create an index first.")
            return []

        query_embedding = self.model.encode([query])[0]

        # Over-fetch (2x) so category filtering can still leave up to k results.
        D, I = self.index.search(np.array([query_embedding]).astype(np.float32),
                                 min(k * 2, len(self.chunk_ids)))

        results = []
        for i, idx in enumerate(I[0]):
            if idx < 0:  # FAISS pads with -1 when it has fewer hits than requested
                continue
            chunk_id = self.chunk_ids[idx]
            chunk = self.chunks[chunk_id]
            if filter_categories and not any(cat in chunk.get('categories', []) for cat in filter_categories):
                continue
            results.append({
                'chunk_id': chunk_id,
                'score': float(D[0][i]),  # L2 distance: lower is better
                'chunk': chunk
            })

        if rerank and results:
            # Rescore (query, passage) pairs with the cross-encoder; higher is better.
            pairs = [(query, result['chunk']['content']) for result in results]
            rerank_scores = self.reranker.predict(pairs)
            for i, score in enumerate(rerank_scores):
                results[i]['rerank_score'] = float(score)
            results = sorted(results, key=lambda x: x['rerank_score'], reverse=True)

        return results[:k]

    def hybrid_search(self,
                      query: str,
                      k: int = 5,
                      filter_categories: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Merge vector search with keyword-frequency matching, then rerank the union."""
        vector_results = self.search(query, k=k, filter_categories=filter_categories, rerank=False)

        # Keyword pass: score every chunk by raw term frequency of the query words.
        keywords = query.lower().split()
        keyword_scores = {}
        for chunk_id, chunk in self.chunks.items():
            if filter_categories and not any(cat in chunk.get('categories', []) for cat in filter_categories):
                continue
            content = (chunk.get('title', '') + " " + chunk.get('content', '')).lower()
            keyword_scores[chunk_id] = sum(content.count(keyword) for keyword in keywords)

        keyword_results = sorted(
            [{'chunk_id': chunk_id, 'score': score, 'chunk': self.chunks[chunk_id]}
             for chunk_id, score in keyword_scores.items() if score > 0],
            key=lambda x: x['score'],
            reverse=True
        )[:k]

        # Merge the two lists, vector hits first, dropping duplicate chunk ids.
        seen_ids = set()
        combined_results = []
        for result in vector_results + keyword_results:
            if result['chunk_id'] not in seen_ids:
                combined_results.append(result)
                seen_ids.add(result['chunk_id'])
        combined_results = combined_results[:k]

        # Rerank the merged top-k with the cross-encoder.
        if combined_results:
            pairs = [(query, result['chunk']['content']) for result in combined_results]
            rerank_scores = self.reranker.predict(pairs)
            for i, score in enumerate(rerank_scores):
                combined_results[i]['rerank_score'] = float(score)
            combined_results = sorted(combined_results, key=lambda x: x['rerank_score'], reverse=True)

        return combined_results
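

# ---------------------------------------------------------------------------
# Usage sketch (hypothetical, not part of the module's public surface). It
# assumes the chunker has already written data/embeddings/embeddings.pkl
# mapping chunk ids to {'embedding': ..., 'chunk': ...} dicts, where each
# chunk carries 'title', 'content', and optional 'categories' fields, as the
# methods above expect. The queries and the "visa" category are examples only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    store = VectorStore(embedding_dir="data/embeddings")

    # Plain vector search with cross-encoder reranking.
    for hit in store.search("How do I apply for OPT?", k=3):
        print(hit['chunk_id'], hit.get('rerank_score', hit['score']))

    # Hybrid search: keyword matches merged with vector hits before reranking.
    hits = store.hybrid_search("F-1 visa work authorization", k=3,
                               filter_categories=["visa"])
    for hit in hits:
        print(hit['chunk_id'], hit.get('rerank_score', hit['score']))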