Dhenenjay's picture
Upload folder using huggingface_hub
adcfb91 verified
# This tests a very minimal of test_add in test_add.py as a example based test
# instead of a property based test. We can use the delta to get the property
# test working and then enable
import random
from typing import List
import numpy as np
from chromadb.api import ClientAPI
import time
from chromadb.api.types import QueryResult
from chromadb.test.conftest import (
COMPACTION_SLEEP,
reset,
skip_if_not_cluster,
)
from chromadb.utils.distance_functions import l2
EPS = 1e-6
@skip_if_not_cluster()
def test_add(
client: ClientAPI,
) -> None:
seed = time.time()
random.seed(seed)
print("Generating data with seed ", seed)
reset(client)
collection = client.create_collection(
name="test",
metadata={"hnsw:construction_ef": 128, "hnsw:search_ef": 128, "hnsw:M": 128},
)
# Add 1000 records, where each embedding has 3 dimensions randomly generated
# between 0 and 1
ids = []
embeddings = []
for i in range(1000):
ids.append(str(i))
embeddings.append([random.random(), random.random(), random.random()])
collection.add(
ids=[str(i)],
embeddings=[embeddings[-1]], # type: ignore
)
random_query = [random.random(), random.random(), random.random()]
print("Generated data with seed ", seed)
# Query the collection with a random query
results = collection.query(
query_embeddings=[random_query], # type: ignore
n_results=10,
include=["distances"], # type: ignore[list-item]
)
# Check that the distances are correct in l2
ground_truth_distances = [
l2(np.array(random_query), np.array(embedding)) for embedding in embeddings
]
ground_truth_distances.sort()
retrieved_distances = results["distances"][0] # type: ignore
# Check that the query results are sorted by distance
for i in range(1, len(retrieved_distances)):
assert retrieved_distances[i - 1] <= retrieved_distances[i]
for i in range(len(retrieved_distances)):
assert np.allclose(ground_truth_distances[i], retrieved_distances[i], atol=EPS)
@skip_if_not_cluster()
def test_add_include_all_with_compaction_delay(client: ClientAPI) -> None:
seed = time.time()
random.seed(seed)
print("Generating data with seed ", seed)
reset(client)
collection = client.create_collection(
name="test_add_include_all_with_compaction_delay",
metadata={"hnsw:construction_ef": 128, "hnsw:search_ef": 128, "hnsw:M": 128},
)
ids = []
embeddings = []
for i in range(1000):
ids.append(str(i))
embeddings.append([random.random(), random.random(), random.random()])
collection.add(
ids=[str(i)],
embeddings=[embeddings[-1]], # type: ignore
documents=f"document_{i}",
)
time.sleep(COMPACTION_SLEEP) # Wait for the documents to be compacted
random_query_1 = [random.random(), random.random(), random.random()]
random_query_2 = [random.random(), random.random(), random.random()]
print("Generated data with seed ", seed)
# Query the collection with a random query
results = collection.query(
query_embeddings=[random_query_1, random_query_2], # type: ignore
n_results=10,
include=["metadatas", "documents", "distances", "embeddings"], # type: ignore[list-item]
)
ids_and_embeddings = list(zip(ids, embeddings))
def validate(results: QueryResult, query: List[float], result_index: int) -> None:
# Check that the distances are correct in l2
gt_ids_distances_embeddings = [
(id, sum((a - b) ** 2 for a, b in zip(embedding, query)), embedding)
for id, embedding in ids_and_embeddings
]
gt_ids_distances_embeddings.sort(key=lambda x: x[1])
retrieved_distances = results["distances"][result_index] # type: ignore
# Check that the query results are sorted by distance
for i in range(1, len(retrieved_distances)):
assert retrieved_distances[i - 1] <= retrieved_distances[i]
for i in range(len(retrieved_distances)):
assert abs(gt_ids_distances_embeddings[i][1] - retrieved_distances[i]) < EPS
# Check that the ids are correct
retrieved_ids = results["ids"][result_index]
for i in range(len(retrieved_ids)):
assert retrieved_ids[i] == gt_ids_distances_embeddings[i][0]
# Check that the documents are correct
if "documents" in results and results["documents"] is not None:
retrieved_documents = results["documents"][result_index]
for i in range(len(retrieved_documents)):
assert (
retrieved_documents[i]
== f"document_{gt_ids_distances_embeddings[i][0]}"
)
else:
assert False
# Check that the embeddings are correct
if "embeddings" in results and results["embeddings"] is not None:
retrieved_embeddings = results["embeddings"][result_index]
for i in range(len(retrieved_embeddings)):
# eps compare the embeddings
for j in range(3):
assert (
abs(
retrieved_embeddings[i][j]
- gt_ids_distances_embeddings[i][2][j]
)
< EPS
)
else:
assert False
validate(results, random_query_1, 0)
validate(results, random_query_2, 1)