Spaces:
Runtime error
Runtime error
| from typing import Optional, Union, TypeVar, List, Dict, Any, Tuple, cast | |
| from numpy.typing import NDArray | |
| import numpy as np | |
| from typing_extensions import Literal, TypedDict, Protocol | |
| import chromadb.errors as errors | |
| from chromadb.types import ( | |
| Metadata, | |
| UpdateMetadata, | |
| Vector, | |
| LiteralValue, | |
| LogicalOperator, | |
| WhereOperator, | |
| OperatorExpression, | |
| Where, | |
| WhereDocumentOperator, | |
| WhereDocument, | |
| ) | |
| from inspect import signature | |
| from tenacity import retry | |
# Re-export types from chromadb.types
# Metadata, Where and WhereDocument are imported from chromadb.types above;
# UpdateCollectionMetadata is assigned later in this module.
__all__ = ["Metadata", "Where", "WhereDocument", "UpdateCollectionMetadata"]

# Generic helper type: either a single value of type T or a list of them.
T = TypeVar("T")
OneOrMany = Union[T, List[T]]

# URIs
# A URI is a plain string; URIs is a list of them.
URI = str
URIs = List[URI]
def maybe_cast_one_to_many_uri(target: OneOrMany[URI]) -> URIs:
    """Return ``target`` as a list of URIs.

    A lone string is wrapped in a one-element list. Any other value is
    returned unchanged and treated as already being a sequence.
    """
    many = [target] if isinstance(target, str) else target
    return cast(URIs, many)
# IDs
# An ID is a plain string; IDs is a list of them.
ID = str
IDs = List[ID]
def maybe_cast_one_to_many_ids(target: OneOrMany[ID]) -> IDs:
    """Return ``target`` as a list of IDs.

    A single string ID becomes a one-element list; any other value is passed
    through untouched on the assumption that it is already a sequence.
    """
    if not isinstance(target, str):
        # Already a sequence
        return cast(IDs, target)
    # One ID
    return cast(IDs, [target])
# Embeddings
# A single embedding is a Vector (imported from chromadb.types);
# Embeddings is a list of them.
Embedding = Vector
Embeddings = List[Embedding]
def maybe_cast_one_to_many_embedding(target: OneOrMany[Embedding]) -> Embeddings:
    """Return ``target`` as a list of embeddings.

    A non-empty list whose first element is an ``int`` or ``float`` is treated
    as one embedding and wrapped in an outer list. Every other value,
    including an empty list, is returned unchanged.

    Args:
        target: One embedding (a flat list of numbers) or a list of embeddings.

    Returns:
        The embeddings as a list.
    """
    # The emptiness check has to precede ``target[0]``; indexing an empty list
    # would raise IndexError instead of letting the caller validate the value.
    if isinstance(target, list) and target and isinstance(target[0], (int, float)):
        # One Embedding
        # cast() ignores its first argument at runtime, so the type is quoted.
        return cast("Embeddings", [target])
    # Already a sequence
    return cast("Embeddings", target)
# Metadatas
# A list of Metadata values (Metadata is imported from chromadb.types).
Metadatas = List[Metadata]
def maybe_cast_one_to_many_metadata(target: OneOrMany[Metadata]) -> Metadatas:
    """Return ``target`` as a list of metadata dicts.

    A single ``dict`` is wrapped in a one-element list; anything else is
    returned as-is and treated as already being a sequence.
    """
    many = [target] if isinstance(target, dict) else target
    return cast(Metadatas, many)
# Collection metadata alias: string keys mapped to values of any type.
CollectionMetadata = Dict[str, Any]
# Alias of UpdateMetadata (imported from chromadb.types); exported via __all__.
UpdateCollectionMetadata = UpdateMetadata

# Documents
# A document is a plain string; Documents is a list of them.
Document = str
Documents = List[Document]
def is_document(target: Any) -> bool:
    """Return True when ``target`` is a ``str``, the type used for documents."""
    return isinstance(target, str)
def maybe_cast_one_to_many_document(target: OneOrMany[Document]) -> Documents:
    """Return ``target`` as a list of documents.

    A value accepted by ``is_document`` is wrapped in a one-element list;
    anything else is returned unchanged as an existing sequence.
    """
    if not is_document(target):
        # Already a sequence
        return cast(Documents, target)
    # One Document
    return cast(Documents, [target])
# Images
# np.float64 is spelled out because the `np.float_` alias was removed in
# NumPy 2.0; on older NumPy releases `np.float_` was the same type.
ImageDType = Union[np.uint, np.int_, np.float64]
# An image is a NumPy array with one of the dtypes above; Images is a list of them.
Image = NDArray[ImageDType]
Images = List[Image]
def is_image(target: Any) -> bool:
    """Return True when ``target`` is a NumPy array with at least two dimensions."""
    return isinstance(target, np.ndarray) and target.ndim >= 2
def maybe_cast_one_to_many_image(target: OneOrMany[Image]) -> Images:
    """Return ``target`` as a list of images.

    A value accepted by ``is_image`` is wrapped in a one-element list;
    anything else is returned unchanged as an existing sequence.
    """
    many = [target] if is_image(target) else target
    return cast(Images, many)
# Type variable constrained to exactly these five types.
Parameter = TypeVar("Parameter", Document, Image, Embedding, Metadata, ID)

# This should just be List[Literal["documents", "embeddings", "metadatas", "distances"]]
# However, this provokes an incompatibility with the Overrides library and Python 3.7
Include = List[
    Union[
        Literal["documents"],
        Literal["embeddings"],
        Literal["metadatas"],
        Literal["distances"],
        Literal["uris"],
        Literal["data"],
    ]
]

# Re-export types from chromadb.types
# The self-assignments rebind the imported names as attributes of this module.
LiteralValue = LiteralValue
LogicalOperator = LogicalOperator
WhereOperator = WhereOperator
OperatorExpression = OperatorExpression
Where = Where
WhereDocumentOperator = WhereDocumentOperator

# Embeddable is either a list of documents or a list of images.
Embeddable = Union[Documents, Images]
# Contravariant type variable bounded by Embeddable; parameterizes EmbeddingFunction.
D = TypeVar("D", bound=Embeddable, contravariant=True)

# A list of images in which individual entries may be None.
Loadable = List[Optional[Image]]
# Covariant type variable bounded by Loadable; parameterizes DataLoader.
L = TypeVar("L", covariant=True, bound=Loadable)
class GetResult(TypedDict):
    """Typed dict holding one flat list per field.

    ``ids`` is always a list; every other field is Optional.
    """

    ids: List[ID]
    embeddings: Optional[List[Embedding]]
    documents: Optional[List[Document]]
    uris: Optional[URIs]
    data: Optional[Loadable]
    metadatas: Optional[List[Metadata]]
class QueryResult(TypedDict):
    """Typed dict holding a list of lists per field.

    ``ids`` is always present as a list of ID lists; every other field is
    Optional.
    """

    # NOTE(review): presumably one inner list per query - confirm against callers.
    ids: List[IDs]
    embeddings: Optional[List[List[Embedding]]]
    documents: Optional[List[List[Document]]]
    uris: Optional[List[List[URI]]]
    data: Optional[List[Loadable]]
    metadatas: Optional[List[List[Metadata]]]
    distances: Optional[List[List[float]]]
class IndexMetadata(TypedDict):
    """Typed dict of bookkeeping values describing an index."""

    dimensionality: int
    # The current number of elements in the index (total = additions - deletes)
    curr_elements: int
    # The auto-incrementing ID of the last inserted element, never decreases so
    # can be used as a count of total historical size. Should increase by 1 every add.
    # Assume cannot overflow
    total_elements_added: int
    time_created: float
class EmbeddingFunction(Protocol[D]):
    """Protocol for callables that map an input of type ``D`` to ``Embeddings``.

    When a subclass is created, ``__init_subclass__`` replaces its ``__call__``
    with a wrapper that passes the result through
    ``maybe_cast_one_to_many_embedding`` and ``validate_embeddings``.
    """

    def __call__(self, input: D) -> Embeddings:
        """Compute embeddings for ``input``; implemented by subclasses."""
        ...

    def __init_subclass__(cls) -> None:
        """Wrap the subclass's ``__call__`` so its output is normalized and validated."""
        super().__init_subclass__()
        # Raise an exception if __call__ is not defined since it is expected to be defined
        # (getattr without a default raises AttributeError when the attribute is missing).
        call = getattr(cls, "__call__")

        # The wrapper closes over ``call``, the __call__ captured just above.
        def __call__(self: EmbeddingFunction[D], input: D) -> Embeddings:
            result = call(self, input)
            return validate_embeddings(maybe_cast_one_to_many_embedding(result))

        # Install the wrapper in place of the subclass's own __call__.
        setattr(cls, "__call__", __call__)

    def embed_with_retries(self, input: D, **retry_kwargs: Dict) -> Embeddings:
        """Call this embedding function through tenacity's ``retry``.

        All keyword arguments are forwarded to ``retry``; the decorated
        ``self.__call__`` is then invoked with ``input``.
        """
        return retry(**retry_kwargs)(self.__call__)(input)
def validate_embedding_function(
    embedding_function: EmbeddingFunction[Embeddable],
) -> None:
    """Check that the function's ``__call__`` has the protocol's parameter names.

    The parameter names of ``__call__`` on the object's class are compared
    with those of ``EmbeddingFunction.__call__``.

    Raises:
        ValueError: If the two sets of parameter names differ.
    """
    actual_params = signature(
        embedding_function.__class__.__call__
    ).parameters.keys()
    expected_params = signature(EmbeddingFunction.__call__).parameters.keys()
    if actual_params == expected_params:
        return
    raise ValueError(
        f"Expected EmbeddingFunction.__call__ to have the following signature: {expected_params}, got {actual_params}\n"
        "Please see https://docs.trychroma.com/embeddings for details of the EmbeddingFunction interface.\n"
        "Please note the recent change to the EmbeddingFunction interface: https://docs.trychroma.com/migration#migration-to-0416---november-7-2023 \n"
    )
class DataLoader(Protocol[L]):
    """Protocol for callables that take a list of URIs and return a value of type ``L``."""

    def __call__(self, uris: URIs) -> L:
        """Load the data for ``uris``; implemented by subclasses."""
        ...
def validate_ids(ids: IDs) -> IDs:
    """Check that ``ids`` is a non-empty list of unique strings and return it.

    Raises:
        ValueError: If ``ids`` is not a list, is empty, or holds a non-string.
        errors.DuplicateIDError: If any ID occurs more than once.
    """
    if not isinstance(ids, list):
        raise ValueError(f"Expected IDs to be a list, got {ids}")
    if not ids:
        raise ValueError(f"Expected IDs to be a non-empty list, got {ids}")

    first_seen = set()
    repeated = set()
    for candidate in ids:
        if not isinstance(candidate, str):
            raise ValueError(f"Expected ID to be a str, got {candidate}")
        # First occurrence goes to ``first_seen``; later ones to ``repeated``.
        bucket = repeated if candidate in first_seen else first_seen
        bucket.add(candidate)

    if not repeated:
        return ids

    repeat_count = len(repeated)
    if repeat_count < 10:
        listing = ", ".join(repeated)
        message = f"Expected IDs to be unique, found duplicates of: {listing}"
    else:
        # Sample at most the first eleven duplicates (in set iteration order)
        # and show the first five and last five of that sample.
        sample = list(repeated)[:11]
        head = ", ".join(sample[:5])
        tail = ", ".join(sample[-5:])
        listing = f"{head}, ..., {tail}"
        message = f"Expected IDs to be unique, found {repeat_count} duplicated IDs: {listing}"
    raise errors.DuplicateIDError(message)
def validate_metadata(metadata: Metadata) -> Metadata:
    """Check that ``metadata`` is None or a non-empty dict of str keys to scalar values.

    Allowed value types are str, int, float and bool.

    Raises:
        ValueError: If ``metadata`` is neither a dict nor None, is an empty
            dict, or contains a value of an unsupported type.
        TypeError: If a key is not a str.
    """
    if metadata is None:
        return metadata
    if not isinstance(metadata, dict):
        raise ValueError(f"Expected metadata to be a dict or None, got {metadata}")
    if not metadata:
        raise ValueError(f"Expected metadata to be a non-empty dict, got {metadata}")

    for name, item in metadata.items():
        if not isinstance(name, str):
            raise TypeError(
                f"Expected metadata key to be a str, got {name} which is a {type(name)}"
            )
        # bool is listed explicitly even though isinstance(True, int) is True
        if not isinstance(item, (bool, str, int, float)):
            raise ValueError(
                f"Expected metadata value to be a str, int, float or bool, got {item} which is a {type(item)}"
            )
    return metadata
def validate_update_metadata(metadata: UpdateMetadata) -> UpdateMetadata:
    """Check that ``metadata`` is None or a non-empty dict usable for an update.

    Keys must be str. Values may be str, int, float, bool or None.

    Raises:
        ValueError: If ``metadata`` is neither a dict nor None, is an empty
            dict, has a non-str key, or contains an unsupported value.
    """
    if metadata is None:
        return metadata
    if not isinstance(metadata, dict):
        raise ValueError(f"Expected metadata to be a dict or None, got {metadata}")
    if not metadata:
        raise ValueError(f"Expected metadata to be a non-empty dict, got {metadata}")

    for name, item in metadata.items():
        if not isinstance(name, str):
            raise ValueError(f"Expected metadata key to be a str, got {name}")
        # None is allowed here; bool is listed explicitly even though
        # isinstance(True, int) is True
        if item is not None and not isinstance(item, (bool, str, int, float)):
            raise ValueError(
                f"Expected metadata value to be a str, int, or float, got {item}"
            )
    return metadata
def validate_metadatas(metadatas: Metadatas) -> Metadatas:
    """Check that ``metadatas`` is a list and run ``validate_metadata`` on each entry.

    Raises:
        ValueError: If ``metadatas`` is not a list. Errors raised by
            ``validate_metadata`` for an entry propagate unchanged.
    """
    if isinstance(metadatas, list):
        for entry in metadatas:
            validate_metadata(entry)
        return metadatas
    raise ValueError(f"Expected metadatas to be a list, got {metadatas}")
def validate_where(where: Where) -> Where:
    """Check the structure of a ``where`` filter and return it unchanged.

    A filter is a dict with exactly one entry. The value is a str, int or
    float, a single-operator expression such as ``{"$gt": 1}``, or, for the
    keys ``$and`` and ``$or``, a list of at least two nested filters.

    Raises:
        ValueError: If any part of the filter violates these rules.
    """
    if not isinstance(where, dict):
        raise ValueError(f"Expected where to be a dict, got {where}")
    if len(where) != 1:
        raise ValueError(f"Expected where to have exactly one operator, got {where}")

    # The length check above guarantees exactly one entry.
    ((field, condition),) = where.items()

    if not isinstance(field, str):
        raise ValueError(f"Expected where key to be a str, got {field}")

    is_logical = field in ("$and", "$or")
    if (
        not is_logical
        and field not in ("$in", "$nin")
        and not isinstance(condition, (str, int, float, dict))
    ):
        raise ValueError(
            f"Expected where value to be a str, int, float, or operator expression, got {condition}"
        )

    if is_logical:
        if not isinstance(condition, list):
            raise ValueError(
                f"Expected where value for $and or $or to be a list of where expressions, got {condition}"
            )
        if len(condition) <= 1:
            raise ValueError(
                f"Expected where value for $and or $or to be a list with at least two where expressions, got {condition}"
            )
        for nested_where in condition:
            validate_where(nested_where)

    # Value is an operator expression
    if isinstance(condition, dict):
        # Ensure there is only one operator
        if len(condition) != 1:
            raise ValueError(
                f"Expected operator expression to have exactly one operator, got {condition}"
            )
        ((operator, operand),) = condition.items()

        # Only numbers can be compared with gt, gte, lt, lte
        if operator in ("$gt", "$gte", "$lt", "$lte") and not isinstance(
            operand, (int, float)
        ):
            raise ValueError(
                f"Expected operand value to be an int or a float for operator {operator}, got {operand}"
            )
        if operator in ("$in", "$nin") and not isinstance(operand, list):
            raise ValueError(
                f"Expected operand value to be an list for operator {operator}, got {operand}"
            )
        if operator not in (
            "$gt",
            "$gte",
            "$lt",
            "$lte",
            "$ne",
            "$eq",
            "$in",
            "$nin",
        ):
            raise ValueError(
                f"Expected where operator to be one of $gt, $gte, $lt, $lte, $ne, $eq, $in, $nin, "
                f"got {operator}"
            )
        if not isinstance(operand, (str, int, float, list)):
            raise ValueError(
                f"Expected where operand value to be a str, int, float, or list of those type, got {operand}"
            )
        # List operands must be non-empty and share the type of their first item.
        if isinstance(operand, list) and (
            len(operand) == 0
            or not all(isinstance(member, type(operand[0])) for member in operand)
        ):
            raise ValueError(
                f"Expected where operand value to be a non-empty list, and all values to obe of the same type "
                f"got {operand}"
            )
    return where
def validate_where_document(where_document: WhereDocument) -> WhereDocument:
    """Check the structure of a ``where_document`` filter and return it unchanged.

    A filter is a dict with exactly one entry whose key is ``$contains``,
    ``$not_contains``, ``$and`` or ``$or``. ``$contains`` and
    ``$not_contains`` take a non-empty string; ``$and`` and ``$or`` take a
    list of at least two nested filters, each validated recursively.

    Args:
        where_document: The filter to validate.

    Returns:
        The same ``where_document`` object.

    Raises:
        ValueError: If any part of the filter violates these rules.
    """
    if not isinstance(where_document, dict):
        raise ValueError(
            f"Expected where document to be a dictionary, got {where_document}"
        )
    if len(where_document) != 1:
        raise ValueError(
            f"Expected where document to have exactly one operator, got {where_document}"
        )
    for operator, operand in where_document.items():
        if operator not in ["$contains", "$not_contains", "$and", "$or"]:
            # The message lists every accepted operator, including $not_contains.
            raise ValueError(
                f"Expected where document operator to be one of $contains, $not_contains, $and, $or, got {operator}"
            )
        if operator == "$and" or operator == "$or":
            if not isinstance(operand, list):
                raise ValueError(
                    f"Expected document value for $and or $or to be a list of where document expressions, got {operand}"
                )
            if len(operand) <= 1:
                raise ValueError(
                    f"Expected document value for $and or $or to be a list with at least two where document expressions, got {operand}"
                )
            for where_document_expression in operand:
                validate_where_document(where_document_expression)
        # Value belongs to a $contains or $not_contains operator; the messages
        # name the operator actually supplied.
        elif not isinstance(operand, str):
            raise ValueError(
                f"Expected where document operand value for operator {operator} to be a str, got {operand}"
            )
        elif len(operand) == 0:
            raise ValueError(
                f"Expected where document operand value for operator {operator} to be a non-empty str"
            )
    return where_document
def validate_include(include: Include, allow_distances: bool) -> Include:
    """Check that ``include`` is a list of allowed field names and return it.

    Args:
        include: Field names to validate.
        allow_distances: Whether ``"distances"`` is accepted. Since get does
            not allow distances, callers use this flag to control it.

    Returns:
        The same ``include`` list.

    Raises:
        ValueError: If ``include`` is not a list, or an item is not a str or
            is not one of the allowed values.
    """
    if not isinstance(include, list):
        raise ValueError(f"Expected include to be a list, got {include}")

    # Built once up front; the set of allowed names does not depend on the item.
    allowed_values = ["embeddings", "documents", "metadatas", "uris", "data"]
    if allow_distances:
        allowed_values.append("distances")

    for item in include:
        if not isinstance(item, str):
            raise ValueError(f"Expected include item to be a str, got {item}")
        if item not in allowed_values:
            raise ValueError(
                f"Expected include item to be one of {', '.join(allowed_values)}, got {item}"
            )
    return include
def validate_n_results(n_results: int) -> int:
    """Check that ``n_results`` is a positive integer and return it.

    Raises:
        ValueError: If ``n_results`` is not an int.
        TypeError: If ``n_results`` is zero or negative.

    NOTE(review): the two exception types look swapped relative to the usual
    convention; they are kept as-is because callers may catch them.
    """
    # Check Number of requested results
    if not isinstance(n_results, int):
        raise ValueError(
            f"Expected requested number of results to be a int, got {n_results}"
        )
    if n_results > 0:
        return n_results
    raise TypeError(
        f"Number of requested results {n_results}, cannot be negative, or zero."
    )
def validate_embeddings(embeddings: Embeddings) -> Embeddings:
    """Check that ``embeddings`` is a non-empty list of non-empty numeric lists.

    Each value must be an int or float; bool values are rejected.

    Raises:
        ValueError: If the outer value is not a list or is empty, if an entry
            is not a list or is empty, or if a value is not an int or float.
    """
    if not isinstance(embeddings, list):
        raise ValueError(f"Expected embeddings to be a list, got {embeddings}")
    if not embeddings:
        raise ValueError(
            f"Expected embeddings to be a list with at least one item, got {embeddings}"
        )
    if any(not isinstance(vector, list) for vector in embeddings):
        raise ValueError(
            f"Expected each embedding in the embeddings to be a list, got {embeddings}"
        )

    for position, vector in enumerate(embeddings):
        if not vector:
            raise ValueError(
                f"Expected each embedding in the embeddings to be a non-empty list, got empty embedding at pos {position}"
            )
        for component in vector:
            # bool is a subclass of int, so it has to be excluded explicitly.
            if isinstance(component, bool) or not isinstance(component, (int, float)):
                raise ValueError(
                    f"Expected each value in the embedding to be a int or float, got {embeddings}"
                )
    return embeddings
def validate_batch(
    batch: Tuple[
        IDs,
        Optional[Embeddings],
        Optional[Metadatas],
        Optional[Documents],
        Optional[URIs],
    ],
    limits: Dict[str, Any],
) -> None:
    """Check that the batch does not exceed ``limits["max_batch_size"]``.

    The batch size is the length of the first tuple element (the IDs).

    Raises:
        ValueError: If the number of IDs is greater than the limit.
    """
    batch_size = len(batch[0])
    size_limit = limits["max_batch_size"]
    if batch_size > size_limit:
        raise ValueError(
            f"Batch size {batch_size} exceeds maximum batch size {size_limit}"
        )