Spaces:
Runtime error
Runtime error
| import faiss | |
| import numpy as np | |
| import pickle | |
| import os | |
class FAISSRetriever:
    """Flat L2 FAISS index paired with a parallel list of documents.

    Position ``i`` in ``documents`` is returned for FAISS hit id ``i``. The
    index and the document list are persisted together as one pickle file
    at ``index_path`` and reloaded on construction when that file exists.
    """

    def __init__(self):
        self.index = None
        self.documents = []
        self.index_path = "faiss_index.pkl"
        if os.path.exists(self.index_path):
            self.load_index()

    def add_documents(self, texts, embeddings):
        """Add embeddings to the index, append ``texts``, and persist both.

        Args:
            texts: Documents to store; a single string is treated as one
                document.
            embeddings: One vector, or a 2-D array-like of vectors. Values
                are converted to float32 before being handed to FAISS.
        """
        # extend() on a bare string would add one entry per character.
        if isinstance(texts, str):
            texts = [texts]

        # FAISS only accepts C-contiguous float32 matrices.
        vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
        if vectors.size == 0:
            return
        if vectors.ndim == 1:
            vectors = vectors.reshape(1, -1)

        if self.index is None:
            self.index = faiss.IndexFlatL2(vectors.shape[1])
        self.index.add(vectors)
        self.documents.extend(texts)
        self.save_index()

    def search(self, image_emb=None, text_emb=None, k=3):
        """Return up to ``k`` stored documents nearest to the query.

        When both embeddings are given they are concatenated (image first)
        into a single query vector; otherwise the one provided is used.
        NOTE(review): the concatenated query is longer than either input, so
        it only matches an index built from vectors of that combined length
        -- confirm against the caller.

        Returns:
            A list of documents, empty when there is no index or no
            embedding was supplied.
        """
        if self.index is None:
            return []

        parts = [np.ravel(emb) for emb in (image_emb, text_emb) if emb is not None]
        if not parts:
            return []

        query = np.ascontiguousarray(np.concatenate(parts), dtype=np.float32)
        _, indices = self.index.search(query.reshape(1, -1), k)

        # FAISS pads missing neighbours with -1; without the lower bound a
        # -1 would silently select the last document.
        doc_count = len(self.documents)
        return [self.documents[i] for i in indices[0] if 0 <= i < doc_count]

    def save_index(self):
        """Write the serialized index and the documents to ``index_path``."""
        # Write to a sibling file first so an interrupted dump cannot leave
        # a truncated pickle at the path that __init__ loads from.
        tmp_path = self.index_path + ".tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump((faiss.serialize_index(self.index), self.documents), f)
        os.replace(tmp_path, self.index_path)

    def load_index(self):
        """Restore the index and the documents from ``index_path``."""
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # an index file this application wrote itself.
        with open(self.index_path, "rb") as f:
            data = pickle.load(f)
        self.index = faiss.deserialize_index(data[0])
        self.documents = data[1]