Spaces:
Sleeping
Sleeping
| __import__("pysqlite3") | |
| import sys | |
| sys.modules["sqlite3"] = sys.modules.pop("pysqlite3") | |
| import uuid | |
| from collections import defaultdict | |
| from typing import Any, List | |
| import chromadb | |
| import numpy as np | |
| from chromadb import Collection | |
| from embeddings import Embedding | |
| from PIL.Image import Image | |
| from utils import base64_to_image | |
class ChromaStore:
    """Store and search image embeddings in a persistent ChromaDB collection.

    Images are kept as base64 strings in the collection's ``documents`` field
    and decoded back to PIL images by :meth:`query`.
    """

    def __init__(
        self,
        collection_name: str,
        storage_path: str = "./chroma",
        database: str = "database",
        metadata: "dict | None" = None,
    ) -> None:
        """Initiate Chromadb.

        Args:
            - collection_name (str): name of the collection.
            - storage_path (str): directory used by the persistent client.
            - database (str): database name. Stored on the instance only; it is
              not passed to the client.
            - metadata (dict, optional): collection metadata applied by
              :meth:`create`. Available options for 'hnsw:space' are 'l2',
              'ip' or 'cosine'. Defaults to ``{"hnsw:space": "cosine"}``.
        """
        self.collection_name = collection_name
        # Build a fresh dict per instance: a dict literal as the default value
        # would be shared by every instance (and by every caller that mutates it).
        self.metadata = (
            {"hnsw:space": "cosine"} if metadata is None else dict(metadata)
        )
        self.storage_path = storage_path
        self.database = database
        self.client = chromadb.PersistentClient(path=self.storage_path)

    def _health_check(self) -> bool:
        """Return True if the client's heartbeat responds with an integer."""
        return isinstance(self.client.heartbeat(), int)

    def generate_embeddings(
        self, images: List[Image], embedding: Embedding
    ) -> np.ndarray:
        """Encode ``images`` by delegating to ``embedding.encode_images``."""
        return embedding.encode_images(images)

    def create(self) -> Collection:
        """Get the configured collection, creating it if it does not exist.

        The instance ``metadata`` (e.g. ``hnsw:space``) is forwarded so the
        requested distance function is actually used; without it Chroma falls
        back to its default ('l2').

        Returns:
            - the ChromaDB collection named ``self.collection_name``.
        """
        # NOTE(review): Chroma fixes the distance function when a collection is
        # first created; a collection created earlier with a different
        # 'hnsw:space' presumably keeps its original one — confirm for the
        # installed chromadb version.
        # An empty dict is passed as None because some Chroma versions reject
        # empty metadata.
        return self.client.get_or_create_collection(
            name=self.collection_name,
            metadata=self.metadata or None,
        )

    def add(
        self,
        collection: Collection,
        embeddings: List[float],
        documents: List[str],
        ids: List[str],
    ):
        """Add embeddings, documents to index or collection.

        Args:
            - collection: created collection.
            - embeddings: list of image embeddings.
            - documents: list of base64 string of images.
            - ids: list of ids for images.

        Raises:
            - Exception: wrapping whatever ``collection.add`` raised; the
              original error is chained as ``__cause__``.
        """
        try:
            collection.add(
                embeddings=embeddings,
                ids=ids,
                documents=documents,
            )
        except Exception as e:
            raise Exception(f"Failed to add documents to Chroma store. {e}") from e

    def query(
        self,
        collection: Collection,
        query_embedding: List[float],
        top_k: int = 3,
    ) -> list:
        """Retrieve relevant images from chroma database.

        Args:
            - collection: created collection.
            - query_embedding: query image embedding.
            - top_k (int): top k images to retrieve.

        Returns:
            - list of ``(PIL image, distance)`` tuples for the first query.
              The number is Chroma's distance in the collection's 'hnsw:space'
              (lower means closer), rounded to 3 decimals.
        """
        result = collection.query(query_embeddings=query_embedding, n_results=top_k)
        # Results are nested per query embedding; only one query is sent, so
        # index [0] selects its hits.
        relevant_images = [
            base64_to_image(img_str) for img_str in result["documents"][0]
        ]
        scores = [round(score, 3) for score in result["distances"][0]]
        return list(zip(relevant_images, scores))

    def delete(self, collection_name: str):
        """Delete the collection named ``collection_name``.

        Returns:
            - True on success.

        Raises:
            - Exception: wrapping the client error, chained as ``__cause__``.
        """
        try:
            self.client.delete_collection(collection_name)
            return True
        except Exception as e:
            raise Exception(
                f"Failed to delete collection {collection_name!r}. {e}"
            ) from e
def collection_info(collection: Collection):
    """Summarise a collection.

    Args:
        - collection: created collection.

    Returns:
        - ``defaultdict(str)`` with key "count" (from ``collection.count()``)
          and key "top_10_items" (from ``collection.peek()``). Any other key
          reads as an empty string.
    """
    summary = defaultdict(str)
    # Keyword arguments are evaluated left to right, so count() runs before
    # peek(), and keys are inserted in that same order.
    summary.update(
        count=collection.count(),
        top_10_items=collection.peek(),
    )
    return summary