|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
|
import re |
|
|
from datetime import datetime |
|
|
from typing import Any, Dict, List, Optional, Tuple |
|
|
|
|
|
from camel.storages.vectordb_storages import ( |
|
|
BaseVectorStorage, |
|
|
VectorDBQuery, |
|
|
VectorDBQueryResult, |
|
|
VectorDBStatus, |
|
|
VectorRecord, |
|
|
) |
|
|
from camel.utils import dependencies_required |
|
|
|
|
|
# Module-level logger; used by MilvusStorage to emit debug messages after
# insert and delete operations.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class MilvusStorage(BaseVectorStorage):
    r"""An implementation of the `BaseVectorStorage` for interacting with
    Milvus, a cloud-native vector search engine.

    The detailed information about Milvus is available at:
    `Milvus <https://milvus.io/docs/overview.md/>`_

    Args:
        vector_dim (int): The dimension of storing vectors.
        url_and_api_key (Tuple[str, str]): Tuple containing
            the URL and API key for connecting to a remote Milvus instance.
            URL maps to Milvus uri concept, typically "endpoint:port".
            API key maps to Milvus token concept, for self-hosted it's
            "username:pwd", for Zilliz Cloud (fully-managed Milvus) it's API
            Key.
        collection_name (Optional[str], optional): Name for the collection in
            the Milvus. If not provided, a name derived from the current
            timestamp (ISO format, sanitized for Milvus) is generated.
            (default: :obj:`None`)
        **kwargs (Any): Additional keyword arguments for initializing
            `MilvusClient`.

    Raises:
        ImportError: If `pymilvus` package is not installed.
        ValueError: If a collection with the given name already exists and
            its vector dimension differs from `vector_dim`.
    """

    @dependencies_required('pymilvus')
    def __init__(
        self,
        vector_dim: int,
        url_and_api_key: Tuple[str, str],
        collection_name: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        from pymilvus import MilvusClient

        self._client: MilvusClient
        self._create_client(url_and_api_key, **kwargs)
        self.vector_dim = vector_dim
        self.collection_name = (
            collection_name or self._generate_collection_name()
        )
        self._check_and_create_collection()

    def _create_client(
        self,
        url_and_api_key: Tuple[str, str],
        **kwargs: Any,
    ) -> None:
        r"""Initializes the Milvus client with the provided connection details.

        Args:
            url_and_api_key (Tuple[str, str]): The URL and API key for the
                Milvus server. The URL is passed as `uri` and the API key as
                `token`.
            **kwargs: Additional keyword arguments passed to the Milvus client.
        """
        from pymilvus import MilvusClient

        self._client = MilvusClient(
            uri=url_and_api_key[0],
            token=url_and_api_key[1],
            **kwargs,
        )

    def _check_and_create_collection(self) -> None:
        r"""Checks if the specified collection exists in Milvus and creates it
        if it doesn't, ensuring it matches the specified vector dimensionality.

        Raises:
            ValueError: If the existing collection's vector dimension differs
                from `self.vector_dim`.
        """
        if self._collection_exists(self.collection_name):
            in_dim = self._get_collection_info(self.collection_name)[
                "vector_dim"
            ]
            if in_dim != self.vector_dim:
                raise ValueError(
                    "Vector dimension of the existing collection "
                    f'"{self.collection_name}" ({in_dim}) is different from '
                    f"the given embedding dim ({self.vector_dim})."
                )
        else:
            self._create_collection(
                collection_name=self.collection_name,
            )

    def _create_collection(
        self,
        collection_name: str,
        **kwargs: Any,
    ) -> None:
        r"""Creates a new collection in the database.

        The schema has three fields: a VARCHAR primary key `id`, a
        FLOAT_VECTOR `vector` of size `self.vector_dim`, and a JSON
        `payload`. Dynamic fields are enabled. A COSINE `AUTOINDEX` index
        named "vector_index" is created on the `vector` field.

        Args:
            collection_name (str): Name of the collection to be created.
            **kwargs (Any): Additional keyword arguments pass to create
                collection.
        """
        from pymilvus import DataType

        schema = self._client.create_schema(
            auto_id=False,
            enable_dynamic_field=True,
            description='collection schema',
        )

        schema.add_field(
            field_name="id",
            datatype=DataType.VARCHAR,
            description='A unique identifier for the vector',
            is_primary=True,
            max_length=65535,
        )

        # `_get_collection_info` can also locate this field by this exact
        # description, so keep it in sync with that method.
        schema.add_field(
            field_name="vector",
            datatype=DataType.FLOAT_VECTOR,
            description='The numerical representation of the vector',
            dim=self.vector_dim,
        )
        schema.add_field(
            field_name="payload",
            datatype=DataType.JSON,
            description=(
                'Any additional metadata or information related '
                'to the vector'
            ),
        )

        self._client.create_collection(
            collection_name=collection_name,
            schema=schema,
            **kwargs,
        )

        index_params = self._client.prepare_index_params()
        index_params.add_index(
            field_name="vector",
            metric_type="COSINE",
            index_type="AUTOINDEX",
            index_name="vector_index",
        )
        self._client.create_index(
            collection_name=collection_name, index_params=index_params
        )

    def _delete_collection(
        self,
        collection_name: str,
    ) -> None:
        r"""Deletes an existing collection from the database.

        Args:
            collection_name (str): Name of the collection to be deleted.
        """
        self._client.drop_collection(collection_name=collection_name)

    def _collection_exists(self, collection_name: str) -> bool:
        r"""Checks whether a collection with the specified name exists in the
        database.

        Args:
            collection_name (str): The name of the collection to check.

        Returns:
            bool: True if the collection exists, False otherwise.
        """
        return self._client.has_collection(collection_name)

    def _generate_collection_name(self) -> str:
        r"""Generates a unique name for a new collection based on the current
        timestamp. Milvus collection names can only contain alphanumeric
        characters and underscores.

        Returns:
            str: A name of the form "Time<timestamp>", where every character
                of the ISO timestamp that is not alphanumeric or an
                underscore is replaced by an underscore.
        """
        timestamp = datetime.now().isoformat()
        transformed_name = re.sub(r'[^a-zA-Z0-9_]', '_', timestamp)
        valid_name = "Time" + transformed_name
        return valid_name

    def _get_collection_info(self, collection_name: str) -> Dict[str, Any]:
        r"""Retrieves details of an existing collection.

        Args:
            collection_name (str): Name of the collection to be checked.

        Returns:
            Dict[str, Any]: A dictionary with keys "id" (collection id),
                "vector_count" (row count), and "vector_dim" (dimension of
                the vector field, or None if no vector field was found).
        """
        vector_count = self._client.get_collection_stats(collection_name)[
            'row_count'
        ]
        collection_info = self._client.describe_collection(collection_name)
        collection_id = collection_info['collection_id']

        # Identify the vector field by its name first; the description check
        # is kept as a fallback so collections that were previously
        # recognised only by their description still are. Only fields that
        # actually carry a `dim` parameter are considered.
        dim_value = next(
            (
                field['params']['dim']
                for field in collection_info['fields']
                if 'dim' in (field.get('params') or {})
                and (
                    field.get('name') == 'vector'
                    or field.get('description')
                    == 'The numerical representation of the vector'
                )
            ),
            None,
        )

        return {
            "id": collection_id,
            "vector_count": vector_count,
            "vector_dim": dim_value,
        }

    def _validate_and_convert_vectors(
        self, records: List[VectorRecord]
    ) -> List[dict]:
        r"""Validates and converts VectorRecord instances to the format
        expected by Milvus.

        Args:
            records (List[VectorRecord]): List of vector records to validate
                and convert.

        Returns:
            List[dict]: A list of dictionaries with keys "id", "payload" and
                "vector". A missing payload (None) is stored as an empty
                string.
        """
        return [
            {
                "id": record.id,
                "payload": record.payload
                if record.payload is not None
                else '',
                "vector": record.vector,
            }
            for record in records
        ]

    def add(
        self,
        records: List[VectorRecord],
        **kwargs: Any,
    ) -> None:
        r"""Adds a list of vectors to the specified collection.

        Args:
            records (List[VectorRecord]): List of vectors to be added.
            **kwargs (Any): Additional keyword arguments pass to insert.

        Raises:
            Exception: Errors raised by the underlying ``MilvusClient.insert``
                call are not caught here and propagate to the caller.
        """
        validated_records = self._validate_and_convert_vectors(records)

        op_info = self._client.insert(
            collection_name=self.collection_name,
            data=validated_records,
            **kwargs,
        )
        logger.debug("Successfully added vectors in Milvus: %s", op_info)

    def delete(
        self,
        ids: List[str],
        **kwargs: Any,
    ) -> None:
        r"""Deletes a list of vectors identified by their IDs from the
        storage. If unsure of ids you can first query the collection to grab
        the corresponding data.

        Args:
            ids (List[str]): List of unique identifiers for the vectors to be
                deleted.
            **kwargs (Any): Additional keyword arguments passed to delete.

        Raises:
            Exception: Errors raised by the underlying ``MilvusClient.delete``
                call are not caught here and propagate to the caller.
        """
        op_info = self._client.delete(
            collection_name=self.collection_name, pks=ids, **kwargs
        )
        logger.debug("Successfully deleted vectors in Milvus: %s", op_info)

    def status(self) -> VectorDBStatus:
        r"""Retrieves the current status of the Milvus collection. This method
        provides information about the collection, including its vector
        dimensionality and the total number of vectors stored.

        Returns:
            VectorDBStatus: An object containing information about the
                collection's status.
        """
        status = self._get_collection_info(self.collection_name)
        return VectorDBStatus(
            vector_dim=status["vector_dim"],
            vector_count=status["vector_count"],
        )

    def query(
        self,
        query: VectorDBQuery,
        **kwargs: Any,
    ) -> List[VectorDBQueryResult]:
        r"""Searches for similar vectors in the storage based on the provided
        query.

        Args:
            query (VectorDBQuery): The query object containing the search
                vector and the number of top similar vectors to retrieve.
            **kwargs (Any): Additional keyword arguments passed to search.

        Returns:
            List[VectorDBQueryResult]: Up to `query.top_k` results, one per
                hit, in the order returned by Milvus. Empty if there are no
                hits.
        """
        search_result = self._client.search(
            collection_name=self.collection_name,
            data=[query.query_vector],
            limit=query.top_k,
            output_fields=['vector', 'payload'],
            **kwargs,
        )

        # `search` returns one list of hits per query vector. Exactly one
        # vector is sent, so every top-k hit lives in the first entry;
        # iterating the outer list would yield at most a single hit.
        if not search_result:
            return []

        return [
            VectorDBQueryResult.create(
                similarity=hit['distance'],
                id=str(hit['id']),
                payload=hit['entity'].get('payload'),
                vector=hit['entity'].get('vector'),
            )
            for hit in search_result[0]
        ]

    def clear(self) -> None:
        r"""Removes all vectors from the Milvus collection. This method
        deletes the existing collection and then recreates it with the same
        schema to effectively remove all stored vectors.
        """
        self._delete_collection(self.collection_name)
        self._create_collection(collection_name=self.collection_name)

    def load(self) -> None:
        r"""Load the collection hosted on cloud service."""
        self._client.load_collection(self.collection_name)

    @property
    def client(self) -> Any:
        r"""Provides direct access to the Milvus client. This property allows
        for direct interactions with the Milvus client for operations that are
        not covered by the `MilvusStorage` class.

        Returns:
            Any: The Milvus client instance.
        """
        return self._client
|
|
|