# challenge-b/ingestion/vector_store.py
# Uploaded by NEXAS (commit 49cf970, "Upload 16 files").
import faiss
import numpy as np
import os
from fastembed import TextEmbedding
from typing import List, Tuple
class VectorStore:
    """Semantic search over text chunks.

    Chunks are embedded with a fastembed ``TextEmbedding`` model and stored in
    an exact (brute-force) FAISS ``IndexFlatL2``. The index and its chunks can
    be persisted to and restored from ``cache_dir`` under a caller-chosen key.
    """

    def __init__(self, model_name: str = "BAAI/bge-small-en-v1.5", cache_dir: str = ".cache"):
        """Create the embedding model and make sure the cache directory exists.

        Args:
            model_name: fastembed model identifier used for both chunks and queries.
            cache_dir: Directory where ``save_index``/``load_index`` keep their files.
        """
        self.encoder = TextEmbedding(model_name=model_name)
        self.index = None  # FAISS index; None until built or loaded
        self.chunks: List[str] = []  # chunk text; position i matches vector id i in the index
        self.cache_dir = cache_dir
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(self.cache_dir, exist_ok=True)

    def _cache_paths(self, key: str) -> Tuple[str, str]:
        """Return the (index file, chunks file) paths used for ``key``."""
        # NOTE(review): key is joined into a path unchecked; callers should pass a trusted key.
        return (
            os.path.join(self.cache_dir, f"{key}.index"),
            os.path.join(self.cache_dir, f"{key}_chunks.npy"),
        )

    def build_index(self, chunks: List[str]) -> None:
        """Embed ``chunks`` and build a fresh FAISS index over them.

        Any previously built or loaded index is replaced. An empty ``chunks``
        clears the store (``search`` then returns ``[]``) instead of failing
        while inferring the embedding dimension.
        """
        # Copy so later mutation of the caller's list (or a one-shot iterable)
        # cannot desynchronise chunk positions from index ids.
        self.chunks = list(chunks)
        if not self.chunks:
            self.index = None
            return
        embeddings = np.asarray(list(self.encoder.embed(self.chunks)), dtype=np.float32)
        # FAISS expects a 2-D float32 matrix of shape (n_chunks, dimension).
        self.index = faiss.IndexFlatL2(embeddings.shape[1])
        self.index.add(embeddings)

    def save_index(self, key: str) -> None:
        """Persist the index and its chunks to ``cache_dir`` under ``key``.

        Writes ``<key>.index`` (FAISS) and ``<key>_chunks.npy`` (NumPy).
        Does nothing if no index has been built or loaded.
        """
        if self.index is None:
            return
        index_path, chunks_path = self._cache_paths(key)
        faiss.write_index(self.index, index_path)
        np.save(chunks_path, np.array(self.chunks))

    def load_index(self, key: str) -> bool:
        """Restore an index and its chunks previously written by ``save_index``.

        Returns:
            True if both files existed, held one chunk per indexed vector, and
            the store was updated. False otherwise; in that case the current
            index and chunks are left untouched.
        """
        index_path, chunks_path = self._cache_paths(key)
        if not (os.path.exists(index_path) and os.path.exists(chunks_path)):
            return False
        index = faiss.read_index(index_path)
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
        # run code from a tampered cache file. save_index writes str chunks as a plain
        # unicode array that loads without it; kept for compatibility with existing caches.
        chunks = np.load(chunks_path, allow_pickle=True).tolist()
        if index.ntotal != len(chunks):
            # Mismatched files would make search() index past the end of chunks.
            return False
        # Assign both together so a failed load never leaves a half-updated store.
        self.index, self.chunks = index, chunks
        return True

    def search(self, query: str, top_k: int = 4) -> List[Tuple[str, float]]:
        """Return up to ``top_k`` chunks closest to ``query``.

        Returns:
            ``(chunk, distance)`` pairs ordered nearest first. Distances come
            from ``IndexFlatL2`` (squared Euclidean), so lower means more
            similar. The list is empty if no index is present, the index is
            empty, or ``top_k < 1``.
        """
        if self.index is None:
            return []
        # Asking for more neighbours than stored vectors only yields -1 padding,
        # and a non-positive k has no meaningful result, so clamp first.
        k = min(top_k, self.index.ntotal)
        if k < 1:
            return []
        query_vec = next(iter(self.encoder.embed([query])))
        query_matrix = np.asarray([query_vec], dtype=np.float32)  # shape (1, dimension)
        distances, ids = self.index.search(query_matrix, k)
        return [
            (self.chunks[int(i)], float(d))
            for i, d in zip(ids[0], distances[0])
            if i != -1  # FAISS marks missing neighbours with -1
        ]
if __name__ == "__main__":
    # Smoke test: index two short sentences, then ask which one mentions a jumping animal.
    sample_docs = ["Hello, world!", "The quick brown fox jumps over the lazy dog."]
    store = VectorStore()
    store.build_index(sample_docs)
    for res, dist in store.search("What animal jumps?"):
        print(f"Result: {res} (Distance: {dist})")