Spaces:
Runtime error
Runtime error
| from typing import Generator, cast | |
| import numpy as np | |
| import pytest | |
| import chromadb | |
| from chromadb.api.types import ( | |
| Embeddable, | |
| EmbeddingFunction, | |
| Embeddings, | |
| Image, | |
| Document, | |
| ) | |
| from chromadb.test.property.strategies import hashing_embedding_function | |
| from chromadb.test.property.invariants import _exact_distances | |
| # A 'standard' multimodal embedding function, which converts inputs to strings | |
| # then hashes them to a fixed dimension. | |
class hashing_multimodal_ef(EmbeddingFunction[Embeddable]):
    """Multimodal embedding function used by the tests in this module.

    Each input item is converted to its string form, embedded with
    ``hashing_embedding_function`` (configured with ``dim=10``) and then
    scaled to unit length.
    """

    def __init__(self) -> None:
        # Use np.float64 instead of the np.float_ alias: the alias was removed
        # in NumPy 2.0, whereas np.float64 names the same dtype everywhere.
        self._hef = hashing_embedding_function(dim=10, dtype=np.float64)

    def __call__(self, input: Embeddable) -> Embeddings:
        """Return one unit-length embedding per item of ``input``."""
        to_texts = [str(i) for i in input]
        embeddings = np.array(self._hef(to_texts))
        # Normalize the embeddings.
        # This is so we can generate random unit vectors and have them be
        # close to the embeddings.
        embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)
        return cast(Embeddings, embeddings.tolist())
def random_image() -> Image:
    """Return a random 10x10x3 ``int32`` array standing in for an image.

    Values come from NumPy's global random state, so seeding that state makes
    the result reproducible. The upper bound (255) is exclusive.
    """
    shape = (10, 10, 3)
    return np.random.randint(low=0, high=255, size=shape, dtype=np.int32)
def random_document() -> Document:
    """Return a random document: the string form of a fresh random image."""
    pixels = random_image()
    return str(pixels)
@pytest.fixture
def multimodal_collection(
    default_ef: EmbeddingFunction[Embeddable] = hashing_multimodal_ef(),
) -> Generator[chromadb.Collection, None, None]:
    """Yield a newly created collection named ``multimodal_collection``.

    Registered as a pytest fixture so that the tests in this module, which
    take ``multimodal_collection`` as an argument, actually receive it.

    Args:
        default_ef: Embedding function attached to the collection.

    Yields:
        The collection created on a new ``chromadb.Client()``.
    """
    client = chromadb.Client()
    collection = client.create_collection(
        name="multimodal_collection", embedding_function=default_ef
    )
    yield collection
    # Teardown: executed after the requesting test has finished.
    # NOTE(review): presumably this resets cached client state so the next
    # test can create the same collection name again - confirm.
    client.clear_system_cache()
| # Test adding and querying of a multimodal collection consisting of images and documents | |
def test_multimodal(
    multimodal_collection: chromadb.Collection,
    default_ef: EmbeddingFunction[Embeddable] = hashing_multimodal_ef(),
    n_examples: int = 10,
    n_query_results: int = 3,
) -> None:
    """Add images and documents to one collection, then query it with both.

    Checks that:
      * adding or querying with a document and an image at once raises
        ``ValueError``;
      * ``get()`` returns every entry, with ``None`` as the document of
        image entries;
      * image and text queries return the ids whose embeddings
        ``_exact_distances`` ranks first for the query embedding.

    Args:
        multimodal_collection: Collection under test.
        default_ef: Embedding function used to compute the expected
            neighbours; it should match the collection's embedding function.
        n_examples: Number of images and number of documents to add.
        n_query_results: Number of neighbours requested per query.
    """
    # Fix numpy's random seed for reproducibility. The previous global state
    # is restored in the ``finally`` clause so that a failing assertion does
    # not leave the fixed seed behind for other tests.
    random_state = np.random.get_state()
    np.random.seed(0)
    try:
        image_ids = [str(i) for i in range(n_examples)]
        images = [random_image() for _ in range(n_examples)]
        image_embeddings = default_ef(images)

        document_ids = [str(i) for i in range(n_examples, 2 * n_examples)]
        documents = [random_document() for _ in range(n_examples)]
        document_embeddings = default_ef(documents)

        # Position k in both lists refers to the same entry: images first,
        # then documents. Neighbour indices can therefore be mapped to ids
        # with a single lookup.
        all_ids = image_ids + document_ids
        all_embeddings = [*image_embeddings, *document_embeddings]

        # Trying to add a document and an image at the same time should fail
        with pytest.raises(
            ValueError, match="You can only provide documents or images, not both."
        ):
            multimodal_collection.add(
                ids=image_ids[0], documents=documents[0], images=images[0]
            )

        # Add some documents
        multimodal_collection.add(ids=document_ids, documents=documents)
        # Add some images
        multimodal_collection.add(ids=image_ids, images=images)

        # get() should return all the documents and images;
        # ids corresponding to images should not have documents
        get_result = multimodal_collection.get(include=["documents"])
        assert len(get_result["ids"]) == len(document_ids) + len(image_ids)
        stored_documents = get_result["documents"]
        assert stored_documents is not None
        expected_documents = dict(zip(document_ids, documents))
        for index, entry_id in enumerate(get_result["ids"]):
            if entry_id in expected_documents:
                assert stored_documents[index] == expected_documents[entry_id]
            else:
                assert entry_id in image_ids
                assert stored_documents[index] is None

        # Generate a random query image and its expected nearest neighbours
        query_image = random_image()
        query_image_embedding = default_ef([query_image])
        image_neighbor_indices, _ = _exact_distances(
            query_image_embedding, all_embeddings
        )
        nearest_image_neighbor_ids = [
            all_ids[i] for i in image_neighbor_indices[0][:n_query_results]
        ]

        # Generate a random query document and its expected nearest neighbours
        query_document = random_document()
        query_document_embedding = default_ef([query_document])
        document_neighbor_indices, _ = _exact_distances(
            query_document_embedding, all_embeddings
        )
        nearest_document_neighbor_ids = [
            all_ids[i] for i in document_neighbor_indices[0][:n_query_results]
        ]

        # Querying with both images and documents should fail
        with pytest.raises(ValueError):
            multimodal_collection.query(
                query_images=[query_image], query_texts=[query_document]
            )

        # Query with images
        query_result = multimodal_collection.query(
            query_images=[query_image],
            n_results=n_query_results,
            include=["documents"],
        )
        assert query_result["ids"][0] == nearest_image_neighbor_ids

        # Query with documents
        query_result = multimodal_collection.query(
            query_texts=[query_document],
            n_results=n_query_results,
            include=["documents"],
        )
        assert query_result["ids"][0] == nearest_document_neighbor_ids
    finally:
        np.random.set_state(random_state)
def test_multimodal_update_with_image(
    multimodal_collection: chromadb.Collection,
) -> None:
    """Updating an entry that has a document with an image removes the document."""
    entry_id = "0"
    original_document = random_document()
    replacement_image = random_image()

    multimodal_collection.add(ids=entry_id, documents=original_document)
    multimodal_collection.update(ids=entry_id, images=replacement_image)

    fetched = multimodal_collection.get(ids=entry_id, include=["documents"])
    stored_documents = fetched["documents"]
    assert stored_documents is not None
    assert stored_documents[0] is None