# Multimodal-rag-chatbot/utils/retriever.py
# Commit d931baa (verified) by Advait3009: "Create utils/retriever.py"
import faiss
import numpy as np
import pickle
import os
class FAISSRetriever:
    """Keep documents next to a flat FAISS L2 index and look them up by embedding.

    Row ``i`` of the index corresponds to ``self.documents[i]``. The pair is
    persisted as one pickle file (``index_path``) that is rewritten after every
    ``add_documents`` call and reloaded on construction when it exists.
    """

    def __init__(self, index_path="faiss_index.pkl"):
        """Create a retriever, restoring saved state if ``index_path`` exists.

        Args:
            index_path: File used to persist the index and documents.
                Defaults to ``"faiss_index.pkl"`` in the working directory.
        """
        self.index = None
        self.documents = []
        self.index_path = index_path
        if os.path.exists(self.index_path):
            self.load_index()

    def add_documents(self, texts, embeddings):
        """Append documents and their embeddings, then persist to disk.

        Args:
            texts: Iterable of documents, one per embedding row. A single
                ``str`` is treated as one document, not as characters.
            embeddings: Array-like of shape ``(n, dim)``; a 1-D vector is
                accepted as a single row. Values are converted to float32.

        Raises:
            ValueError: If the number of documents and embedding rows differ,
                the embeddings are not 1-D/2-D, or ``dim`` does not match the
                existing index.
        """
        texts = [texts] if isinstance(texts, str) else list(texts)
        # FAISS operates on C-contiguous float32 matrices.
        vectors = np.ascontiguousarray(embeddings, dtype=np.float32)

        if vectors.size == 0 and not texts:
            return  # nothing to add, nothing to save

        if vectors.ndim == 1:
            vectors = vectors.reshape(1, -1)
        if vectors.ndim != 2:
            raise ValueError(
                f"embeddings must be 1-D or 2-D, got {vectors.ndim} dimensions"
            )
        # Documents are looked up by index row, so the counts must agree.
        if vectors.shape[0] != len(texts):
            raise ValueError(
                f"got {len(texts)} texts but {vectors.shape[0]} embeddings"
            )

        if self.index is None:
            self.index = faiss.IndexFlatL2(vectors.shape[1])
        elif vectors.shape[1] != self.index.d:
            raise ValueError(
                f"embedding dimension {vectors.shape[1]} does not match "
                f"index dimension {self.index.d}"
            )

        self.index.add(vectors)
        self.documents.extend(texts)
        self.save_index()

    def search(self, image_emb=None, text_emb=None, k=3):
        """Return up to ``k`` stored documents nearest to the query embedding.

        Args:
            image_emb: Optional image embedding.
            text_emb: Optional text embedding.
            k: Maximum number of documents to return.

        Returns:
            A list of documents in the order FAISS reports them; empty when
            the index is empty, no embedding is given, or ``k`` is not positive.
        """
        if self.index is None or k <= 0:
            return []

        if image_emb is not None and text_emb is not None:
            # NOTE(review): the concatenated query is longer than either input;
            # this presumably only works if the index was built from
            # concatenated embeddings too - confirm against the caller.
            query = np.concatenate([image_emb, text_emb])
        elif image_emb is not None:
            query = image_emb
        elif text_emb is not None:
            query = text_emb
        else:
            return []  # no query supplied

        query = np.ascontiguousarray(query, dtype=np.float32).reshape(1, -1)
        _, indices = self.index.search(query, k)

        # FAISS pads with -1 when fewer than k neighbours exist; without the
        # lower bound, -1 would wrongly select the last document.
        total = len(self.documents)
        return [self.documents[i] for i in indices[0] if 0 <= i < total]

    def save_index(self):
        """Write the serialized index and documents to ``index_path``."""
        # Write to a sibling temp file and swap it in, so an interrupted save
        # cannot leave a truncated pickle behind for the next start-up.
        tmp_path = f"{self.index_path}.tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump((faiss.serialize_index(self.index), self.documents), f)
        os.replace(tmp_path, self.index_path)

    def load_index(self):
        """Restore the index and documents from ``index_path``."""
        # SECURITY: pickle.load can execute arbitrary code; only load index
        # files that this application wrote itself.
        with open(self.index_path, "rb") as f:
            serialized_index, documents = pickle.load(f)
        self.index = faiss.deserialize_index(serialized_index)
        self.documents = documents