File size: 2,171 Bytes
6a3ee0f
c2657ee
6a3ee0f
 
 
 
 
 
 
8be4889
6a3ee0f
 
 
 
 
8be4889
 
6a3ee0f
 
8be4889
6a3ee0f
 
8be4889
6a3ee0f
 
 
 
 
 
 
8be4889
6a3ee0f
 
 
 
 
 
8be4889
6a3ee0f
 
 
 
 
 
 
 
8be4889
6a3ee0f
 
 
 
 
8be4889
6a3ee0f
 
 
 
 
 
8be4889
6a3ee0f
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import pickle
from pathlib import Path

import faiss
import numpy as np
from PIL import Image


class VectorStore:
    """Persistent FAISS-backed nearest-neighbour store with per-vector metadata.

    Vectors are kept in a ``faiss.IndexFlatIP`` index; row ``i`` of the index
    is described by ``self._metadata[i]``. Both are read from ``INDEX_PATH`` /
    ``META_PATH`` on construction (when both files exist) and rewritten to
    disk after every successful ``add``.
    """

    INDEX_PATH = Path("data/index.faiss")
    META_PATH = Path("data/metadata.pkl")

    def __init__(self):
        """Create an empty store, then load any previously saved state."""
        self._index = None
        self._metadata = []
        self._try_load_from_disk()

    def is_index_loaded(self) -> bool:
        """Return True when an index exists and holds at least one vector."""
        return self._index is not None and self._index.ntotal > 0

    def add(self, vectors, metadata) -> None:
        """Append vectors and their metadata, then persist both to disk.

        Args:
            vectors: Array-like of shape ``(n, dim)``; a single 1-D vector is
                treated as one row. Converted to C-contiguous float32.
            metadata: Iterable with exactly one entry per vector, in row order.

        Raises:
            ValueError: If ``vectors`` is not 1-D/2-D, or if the number of
                vectors differs from the number of metadata entries.
        """
        entries = list(metadata)
        vecs = np.ascontiguousarray(vectors, dtype="float32")

        # Nothing to add: leave the index and the files on disk untouched.
        if vecs.size == 0 and not entries:
            return

        if vecs.ndim == 1:
            vecs = vecs.reshape(1, -1)
        if vecs.ndim != 2:
            raise ValueError(
                f"vectors must be 1-D or 2-D, got {vecs.ndim} dimensions"
            )
        # Index rows and metadata entries are matched purely by position, so
        # a count mismatch would silently misalign every later search result.
        if vecs.shape[0] != len(entries):
            raise ValueError(
                f"got {vecs.shape[0]} vectors but {len(entries)} metadata entries"
            )

        if self._index is None:
            self._index = faiss.IndexFlatIP(vecs.shape[1])
        self._index.add(vecs)
        self._metadata.extend(entries)
        self._save_to_disk()

    def search(self, query_vec, k: int = 6) -> list:
        """Return up to ``k`` nearest neighbours of ``query_vec``.

        Args:
            query_vec: Array-like query; flattened into a single row.
            k: Maximum number of results. Values below 1 yield no results.

        Returns:
            A list of dicts in the order returned by the index. Each dict is
            a shallow copy of the stored metadata entry with a float
            ``"score"`` added. When an entry has ``"image_url"`` but no
            ``"image"``, ``"image"`` is set to the downloaded RGB image, or
            ``None`` if the download failed. Empty list when nothing is
            indexed.
        """
        if not self.is_index_loaded():
            return []

        # Asking for more neighbours than exist only produces -1 padding.
        k = min(int(k), int(self._index.ntotal))
        if k <= 0:
            return []

        query = np.ascontiguousarray(query_vec, dtype="float32").reshape(1, -1)
        scores, indices = self._index.search(query, k)

        known = len(self._metadata)
        results = []
        for score, idx in zip(scores[0], indices[0]):
            idx = int(idx)
            # -1 marks "no result"; rows beyond the metadata list can only
            # come from index/metadata files that are out of sync on disk.
            if idx < 0 or idx >= known:
                continue
            item = dict(self._metadata[idx])  # copy: never mutate stored entry
            item["score"] = float(score)
            if "image" not in item and "image_url" in item:
                item["image"] = self._load_image(item["image_url"])
            results.append(item)
        return results

    def _try_load_from_disk(self) -> None:
        """Load index and metadata when both files exist; otherwise do nothing."""
        if self.INDEX_PATH.exists() and self.META_PATH.exists():
            self._index = faiss.read_index(str(self.INDEX_PATH))
            # NOTE(security): unpickling can execute arbitrary code. Only load
            # a metadata file that this application wrote itself.
            with open(self.META_PATH, "rb") as f:
                self._metadata = pickle.load(f)

    def _save_to_disk(self) -> None:
        """Write the index and the metadata list, creating the directory."""
        self.INDEX_PATH.parent.mkdir(parents=True, exist_ok=True)
        faiss.write_index(self._index, str(self.INDEX_PATH))
        with open(self.META_PATH, "wb") as f:
            pickle.dump(self._metadata, f)

    @staticmethod
    def _load_image(url):
        """Download ``url`` and return it as an RGB PIL image, or None.

        Deliberately best-effort: any failure (``requests`` not installed,
        network error, HTTP error status, undecodable image) yields ``None``
        so that one bad thumbnail cannot break a whole search.
        """
        try:
            import requests
            from io import BytesIO

            resp = requests.get(url, timeout=5)
            # Do not hand an HTTP error page to the image decoder.
            resp.raise_for_status()
            return Image.open(BytesIO(resp.content)).convert("RGB")
        except Exception:
            return None