Image-Retrieval-System / src /vector_store.py
s1ngledoge's picture
upd
1c8d1ba
Raw
History Blame Contribute Delete
2.73 kB
from __future__ import annotations
import logging
from collections.abc import Iterable
from typing import Any
from src.config import ALL_CATEGORIES_LABEL, ConfigurationError, Settings, load_settings
# Module-level logger; every VectorStore operation below reports through it.
logger = logging.getLogger(__name__)
def build_category_filter(category: str | None) -> str | None:
    """Build an Upstash Vector metadata filter that restricts results to one category.

    Args:
        category: Category name to filter on. ``None``, an empty string, or
            ``ALL_CATEGORIES_LABEL`` all mean "do not restrict by category".

    Returns:
        A filter expression of the form ``category = '<value>'``, with
        backslashes and single quotes in the value backslash-escaped, or
        ``None`` when no filter should be applied.
    """
    if category and category != ALL_CATEGORIES_LABEL:
        # Escape backslashes first so the backslashes introduced for quotes
        # are not themselves doubled afterwards.
        quoted = category.replace("\\", "\\\\")
        quoted = quoted.replace("'", "\\'")
        return f"category = '{quoted}'"
    return None
class VectorStore:
    """Wrapper around an Upstash Vector ``Index`` bound to a single namespace.

    Every upsert, query and reset issued through this object targets
    ``settings.upstash_namespace``.
    """

    def __init__(self, settings: Settings | None = None) -> None:
        """Create the Upstash Vector client.

        Args:
            settings: Application settings; when ``None``, ``load_settings()``
                is called to obtain them.

        Raises:
            ConfigurationError: if the ``upstash-vector`` package is not
                installed, or if the Upstash URL or token is missing.
        """
        # Explicit ``is None`` test so a caller-supplied Settings object is
        # never silently replaced just because it evaluates as falsy.
        self.settings = settings if settings is not None else load_settings()
        self.settings.require_upstash()
        try:
            # Imported lazily so the rest of the app can import this module
            # without the optional dependency installed.
            from upstash_vector import Index
        except ImportError as exc:
            raise ConfigurationError(
                "The upstash-vector package is not installed. Run `pip install -r requirements.txt`."
            ) from exc
        url = self.settings.upstash_url
        token = self.settings.upstash_token
        # Real checks instead of ``assert``: assert statements are stripped
        # under ``python -O``, which would let ``None`` credentials reach the client.
        if url is None or token is None:
            raise ConfigurationError("Upstash Vector URL and token must both be configured.")
        self.namespace = self.settings.upstash_namespace
        self.index = Index(url=url, token=token)
        logger.info("Connected to Upstash Vector namespace %s", self.namespace)

    def upsert_image_vector(self, id: str, vector: list[float], metadata: dict[str, Any]) -> None:
        """Upsert a single vector with its metadata.

        ``id`` shadows the builtin but is kept because callers may pass it by keyword.
        """
        self.upsert_many([{"id": id, "vector": vector, "metadata": metadata}])

    def upsert_many(self, items: Iterable[dict[str, Any]]) -> None:
        """Upsert all *items* in one ``index.upsert`` call.

        Each item must provide ``"id"``, ``"vector"`` and ``"metadata"`` keys
        (a missing key raises ``KeyError``); any other keys are dropped.
        Does nothing when *items* is empty.
        """
        vectors = [
            {"id": item["id"], "vector": item["vector"], "metadata": item["metadata"]}
            for item in items
        ]
        if not vectors:
            return
        logger.info("Upserting %d vector(s) to namespace %s", len(vectors), self.namespace)
        self.index.upsert(vectors=vectors, namespace=self.namespace)

    def query_vector(
        self,
        vector: list[float],
        top_k: int,
        category: str | None = None,
    ) -> list[Any]:
        """Return up to *top_k* nearest matches for *vector*, with metadata.

        Args:
            vector: Query embedding.
            top_k: Maximum number of results; coerced with ``int()``.
            category: Optional category to restrict results to; ``None`` or the
                "all categories" label disables filtering.

        Returns:
            The query results as a list (vectors themselves are not included).
        """
        limit = int(top_k)
        query_filter = build_category_filter(category)
        kwargs: dict[str, Any] = {
            "vector": vector,
            "top_k": limit,
            "include_metadata": True,
            "include_vectors": False,
            "namespace": self.namespace,
        }
        if query_filter:
            kwargs["filter"] = query_filter
        # Log the coerced value so the log matches what is actually sent.
        logger.info(
            "Querying Upstash Vector with top_k=%s category=%s",
            limit,
            category or ALL_CATEGORIES_LABEL,
        )
        return list(self.index.query(**kwargs))

    def clear_namespace(self) -> None:
        """Reset (clear) this store's namespace via ``index.reset``."""
        logger.warning("Clearing Upstash Vector namespace %s", self.namespace)
        self.index.reset(namespace=self.namespace)