File size: 4,326 Bytes
63670e0 e767989 63670e0 e767989 63670e0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
from pymilvus import MilvusClient, Collection, connections
from os.path import join
import json
from typing import List, Dict, Union
import numpy as np
class VectorDBManager:
    """
    A manager class for handling vector operations in a Milvus database.

    This class provides methods for retrieving, searching, and managing vectors
    in various collections within a Milvus database, such as finding vectors by
    ID, searching for nearest neighbors, and listing collections.

    Attributes:
        _client (MilvusClient): Client bound to the database file
            ``src/db/gars.db`` (path relative to the working directory).
        _params (dict): Configuration loaded from ``config/db_config.json``;
            ``find_knn`` reads its ``"metric_type"`` key.
    """

    def __init__(self):
        """
        Initializes the VectorDBManager.

        Loads the configuration file ``config/db_config.json`` and then connects
        to the Milvus database stored at ``src/db/gars.db``. Both paths are
        relative to the current working directory. The database is opened through
        a ``MilvusClient`` and is also registered with the pymilvus ORM
        ``connections`` registry under the ``"default"`` alias.

        Raises:
            OSError: If the configuration file cannot be opened.
            json.JSONDecodeError: If the configuration file is not valid JSON.
        """
        config_path_name = join("config", "db_config.json")
        db_uri = join("src", "db", "gars.db")

        # Read the configuration first so that a missing or malformed config
        # file fails before any database connection is opened.
        with open(config_path_name, "r", encoding="utf-8") as f:
            self._params = json.load(f)

        self._client = MilvusClient(uri=db_uri)
        # Also register the "default" ORM alias so pymilvus ORM objects
        # (e.g. Collection) can operate on the same database.
        connections.connect(alias="default", uri=db_uri)

    def find_by_id(self, collection_name: str, id: int) -> Union[Dict, None]:
        """
        Retrieves a vector by its unique identifier from the specified collection.

        Args:
            collection_name (str): The name of the collection to search in.
            id (int): The unique identifier of the vector to retrieve.

        Returns:
            Dict: The stored entity, with its ``"vector"`` field converted to a
            NumPy array. ``None`` is never returned; a missing ID raises instead.

        Raises:
            ValueError: If no entity with the given ID exists in the collection.
        """
        res = self._client.get(collection_name=collection_name, ids=[id])
        # ``MilvusClient.get`` returns a list of matching entities, which is
        # empty for an unknown ID. Testing emptiness also covers a None result.
        if not res:
            raise ValueError(
                f"Vector with ID {id} in collection '{collection_name}' was not found!"
            )
        entity = res[0]
        # Convert vector data to a NumPy array for downstream numeric code.
        entity["vector"] = np.asarray(entity["vector"])
        return entity

    def find_knn(
        self,
        collection_name: str,
        query_vector: List[float],
        num_neighbors: int,
    ) -> List[Dict]:
        """
        Finds the k-nearest neighbors to a query vector in the specified collection.

        Args:
            collection_name (str): The name of the collection to search.
            query_vector (List[float]): The query vector to search for similar items.
            num_neighbors (int): The number of nearest neighbors to retrieve.

        Returns:
            List[Dict]: The hits for the query vector, one dictionary per
            neighbor. The list is empty if the search returned no result set.
        """
        search_params = {
            "metric_type": self._params["metric_type"],
            "params": {},
        }
        # The keyword must be ``search_params`` (plural). A misspelled keyword
        # is swallowed by ``**kwargs`` and the configured metric is ignored.
        res = self._client.search(
            collection_name=collection_name,
            data=[query_vector],
            anns_field="vector",
            limit=num_neighbors,
            search_params=search_params,
        )
        # Exactly one query vector was sent, so the first result set holds its hits.
        return res[0] if res else []

    def list_collections(self) -> List[str]:
        """
        Lists all collections in the Milvus database.

        Returns:
            List[str]: A list of all collection names in the database.
        """
        return self._client.list_collections()

    def get_collection_size(self, collection_name: str) -> int:
        """
        Retrieves the number of entries in a specified collection.

        Args:
            collection_name (str): The name of the collection to get the size for.

        Returns:
            int: The ``"row_count"`` reported by the collection statistics.
        """
        collection_stats = self._client.get_collection_stats(
            collection_name=collection_name
        )
        return collection_stats["row_count"]
|