Spaces:
Sleeping
Sleeping
File size: 2,171 Bytes
6a3ee0f c2657ee 6a3ee0f 8be4889 6a3ee0f 8be4889 6a3ee0f 8be4889 6a3ee0f 8be4889 6a3ee0f 8be4889 6a3ee0f 8be4889 6a3ee0f 8be4889 6a3ee0f 8be4889 6a3ee0f 8be4889 6a3ee0f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 | import pickle
from pathlib import Path
import faiss
import numpy as np
from PIL import Image
class VectorStore:
    """Read/write the FAISS index and metadata, serve nearest-neighbour results.

    The FAISS index and a parallel list of metadata mappings are kept in
    lock-step: the i-th vector added to the index corresponds to
    ``self._metadata[i]``. Both are persisted to disk after every ``add``.
    """

    # On-disk locations of the FAISS index and the pickled metadata list.
    INDEX_PATH = Path("data/index.faiss")
    META_PATH = Path("data/metadata.pkl")

    def __init__(self):
        self._index = None  # FAISS index; created lazily by the first add()
        self._metadata = []  # one entry per indexed vector, in index order
        self._try_load_from_disk()

    def is_index_loaded(self):
        """Return True when an index exists and holds at least one vector."""
        return self._index is not None and self._index.ntotal > 0

    def add(self, vectors, metadata):
        """Append ``vectors`` and their ``metadata`` to the store and persist it.

        Args:
            vectors: array-like of shape ``(n, dim)``; a single 1-D vector is
                treated as a batch of one. Converted to float32 for FAISS.
            metadata: iterable of ``n`` mappings, one per vector, in the same
                order. A copy of each is returned by ``search``.

        Raises:
            ValueError: if ``vectors`` is not 1-D/2-D, or the number of
                vectors differs from the number of metadata entries. All
                validation happens before any state is modified.
        """
        metadata = list(metadata)
        vecs = np.asarray(vectors, dtype="float32")
        if vecs.size == 0 and not metadata:
            # Nothing to add: don't create an empty index or rewrite files.
            return
        if vecs.ndim == 1:
            vecs = vecs.reshape(1, -1)
        if vecs.ndim != 2:
            raise ValueError(f"vectors must be 1-D or 2-D, got shape {vecs.shape}")
        if vecs.shape[0] != len(metadata):
            # A mismatch would silently misalign FAISS ids and metadata rows.
            raise ValueError(
                f"got {vecs.shape[0]} vectors but {len(metadata)} metadata entries"
            )
        if self._index is None:
            # Inner-product index; for cosine similarity the caller presumably
            # L2-normalises the vectors first (TODO confirm at call sites).
            self._index = faiss.IndexFlatIP(vecs.shape[1])
        # FAISS reads the raw buffer, so hand it a C-contiguous float32 array.
        self._index.add(np.ascontiguousarray(vecs))
        self._metadata.extend(metadata)
        self._save_to_disk()

    def search(self, query_vec, k=6):
        """Return up to ``k`` metadata entries nearest to ``query_vec``.

        Args:
            query_vec: array-like embedding, flattened into a single query row.
            k: maximum number of hits; values <= 0 yield an empty list.

        Returns:
            A list of metadata dict copies in the order returned by FAISS,
            each with an added ``"score"`` float. An entry that has an
            ``"image_url"`` but no ``"image"`` gets its image fetched on the
            fly (``None`` if the fetch fails). Returns ``[]`` when no index
            is loaded.
        """
        if not self.is_index_loaded():
            return []
        # Requesting more neighbours than exist only produces -1 padding.
        k = min(int(k), self._index.ntotal)
        if k <= 0:
            return []
        q = np.ascontiguousarray(np.asarray(query_vec, dtype="float32").reshape(1, -1))
        scores, indices = self._index.search(q, k)
        results = []
        for score, idx in zip(scores[0], indices[0]):
            idx = int(idx)
            # -1 means "no result"; an id past the metadata list means the
            # index and metadata files are out of sync -- skip, don't crash.
            if not 0 <= idx < len(self._metadata):
                continue
            item = dict(self._metadata[idx])
            item["score"] = float(score)
            if "image" not in item and "image_url" in item:
                # One blocking HTTP request per hit; results are not cached.
                item["image"] = self._load_image(item["image_url"])
            results.append(item)
        return results

    def _try_load_from_disk(self):
        """Load a previously saved index and metadata if both files exist.

        SECURITY: ``pickle.load`` can execute arbitrary code. Only point
        ``META_PATH`` at files this application wrote itself.
        """
        if not (self.INDEX_PATH.exists() and self.META_PATH.exists()):
            return
        self._index = faiss.read_index(str(self.INDEX_PATH))
        with open(self.META_PATH, "rb") as f:
            self._metadata = pickle.load(f)
        if self._index.ntotal != len(self._metadata):
            import logging

            logging.getLogger(__name__).warning(
                "FAISS index has %d vectors but metadata has %d entries; "
                "unmatched search hits will be skipped",
                self._index.ntotal,
                len(self._metadata),
            )

    def _save_to_disk(self):
        """Persist the index and metadata, replacing any previous files.

        Each file is first written to a ``.tmp`` sibling and then renamed
        over its target, so an interrupted write never leaves a truncated
        file in place. The two renames are still separate steps.
        """
        if self._index is None:
            return
        for path in (self.INDEX_PATH, self.META_PATH):
            path.parent.mkdir(parents=True, exist_ok=True)
        index_tmp = self.INDEX_PATH.with_name(self.INDEX_PATH.name + ".tmp")
        meta_tmp = self.META_PATH.with_name(self.META_PATH.name + ".tmp")
        faiss.write_index(self._index, str(index_tmp))
        with open(meta_tmp, "wb") as f:
            pickle.dump(self._metadata, f)
        index_tmp.replace(self.INDEX_PATH)
        meta_tmp.replace(self.META_PATH)

    @staticmethod
    def _load_image(url):
        """Download ``url`` and decode it as an RGB PIL image.

        Best effort: any failure (``requests`` missing, network error, HTTP
        error status, undecodable content) is logged at debug level and
        ``None`` is returned instead of raising.
        """
        import logging

        try:
            from io import BytesIO

            import requests

            resp = requests.get(url, timeout=5)
            # Don't hand an HTML error page to PIL.
            resp.raise_for_status()
            return Image.open(BytesIO(resp.content)).convert("RGB")
        except Exception:
            logging.getLogger(__name__).debug(
                "Could not load image from %r", url, exc_info=True
            )
            return None