Atelier-AI / database.py
Priyanshiiiii's picture
Update database.py
8be4889 verified
import pickle
from pathlib import Path
import faiss
import numpy as np
from PIL import Image
class VectorStore:
    """Read/write the FAISS index and metadata, serve nearest-neighbour results.

    Vectors are stored in a flat inner-product FAISS index
    (``faiss.IndexFlatIP``).  ``_metadata`` is a list kept in insertion
    order, so the record for the vector at index position ``i`` is
    ``_metadata[i]``.  Both are written to disk after every ``add`` and
    reloaded (when both files exist) on construction.
    """

    INDEX_PATH = Path("data/index.faiss")
    META_PATH = Path("data/metadata.pkl")

    def __init__(self):
        """Create an empty store, then load any previously saved state."""
        self._index = None
        self._metadata = []
        self._try_load_from_disk()

    def is_index_loaded(self) -> bool:
        """Return True when an index exists and holds at least one vector."""
        return self._index is not None and self._index.ntotal > 0

    def add(self, vectors, metadata):
        """Append vectors and their metadata records, then persist both.

        Args:
            vectors: 2-D array-like of shape ``(n, dim)``; a single 1-D
                vector is accepted and treated as one row.
            metadata: Iterable of ``n`` records, one per row of ``vectors``,
                in the same order.

        Raises:
            ValueError: If ``vectors`` is not 1-D/2-D, or the number of
                metadata records differs from the number of vectors.
        """
        # FAISS expects a C-contiguous float32 matrix.
        vectors = np.ascontiguousarray(vectors, dtype="float32")
        if vectors.ndim == 1:
            vectors = vectors.reshape(1, -1)
        if vectors.ndim != 2:
            raise ValueError(
                f"vectors must be a 2-D array, got {vectors.ndim} dimensions"
            )

        # Validate before mutating anything: a count mismatch would shift
        # every later metadata lookup and be saved to disk.
        metadata = list(metadata)
        if len(metadata) != vectors.shape[0]:
            raise ValueError(
                f"got {vectors.shape[0]} vectors but {len(metadata)} "
                "metadata records"
            )

        if self._index is None:
            # The first batch fixes the dimensionality of the index.
            self._index = faiss.IndexFlatIP(vectors.shape[1])

        self._index.add(vectors)
        self._metadata.extend(metadata)
        self._save_to_disk()

    def search(self, query_vec, k: int = 6) -> list:
        """Return up to ``k`` nearest neighbours of ``query_vec``.

        Each result is a copy of the stored metadata record with a
        ``"score"`` key (the score reported by the index, as a float).
        When a record has an ``"image_url"`` but no ``"image"``, the image
        is fetched and stored under ``"image"`` (``None`` if the fetch
        fails).

        Returns an empty list when the store holds no vectors or ``k`` is
        not positive.
        """
        if not self.is_index_loaded():
            return []

        # Never ask for more neighbours than there are stored vectors.
        k = min(int(k), self._index.ntotal)
        if k <= 0:
            return []

        q = np.ascontiguousarray(query_vec, dtype="float32").reshape(1, -1)
        scores, indices = self._index.search(q, k)

        results = []
        for score, idx in zip(scores[0], indices[0]):
            idx = int(idx)
            # Negative ids mark "no neighbour"; an id beyond the metadata
            # list means index and metadata are out of sync, so skip it
            # rather than raising IndexError.
            if idx < 0 or idx >= len(self._metadata):
                continue
            item = dict(self._metadata[idx])
            item["score"] = float(score)
            if "image" not in item and "image_url" in item:
                item["image"] = self._load_image(item["image_url"])
            results.append(item)
        return results

    def _try_load_from_disk(self):
        """Load the index and metadata when both files are present."""
        if self.INDEX_PATH.exists() and self.META_PATH.exists():
            self._index = faiss.read_index(str(self.INDEX_PATH))
            with open(self.META_PATH, "rb") as f:
                # SECURITY: pickle can execute arbitrary code while loading.
                # Only safe if META_PATH is written exclusively by this
                # application and is not user-supplied.
                self._metadata = pickle.load(f)

    def _save_to_disk(self):
        """Write the current index and metadata to their configured paths."""
        # Create both parent directories in case the two paths differ.
        self.INDEX_PATH.parent.mkdir(parents=True, exist_ok=True)
        self.META_PATH.parent.mkdir(parents=True, exist_ok=True)
        faiss.write_index(self._index, str(self.INDEX_PATH))
        with open(self.META_PATH, "wb") as f:
            pickle.dump(self._metadata, f)

    @staticmethod
    def _load_image(url):
        """Fetch ``url`` and return it as an RGB PIL image, or None on failure.

        Best-effort by design: any network, HTTP or decoding error yields
        ``None`` so one bad image does not break a search.
        """
        try:
            import requests
            from io import BytesIO

            resp = requests.get(url, timeout=5)
            # Do not pass an HTTP error page to the image decoder.
            resp.raise_for_status()
            return Image.open(BytesIO(resp.content)).convert("RGB")
        except Exception:
            return None