# Milvus vector-store manager for ColBERT-style multi-vector (late-interaction) retrieval.
| from pymilvus import MilvusClient, DataType | |
| try: | |
| from milvus import default_server # Milvus Lite | |
| except Exception: | |
| default_server = None | |
| import numpy as np | |
| import concurrent.futures | |
| from pymilvus import Collection | |
| import os | |
class MilvusManager:
    """Manage a Milvus collection of ColBERT-style multi-vector embeddings.

    Handles connection (embedded Milvus Lite when available, otherwise a
    local standalone server), collection/schema creation, HNSW indexing,
    multi-collection search with MaxSim reranking, and insertion.
    """

    def __init__(self, milvus_uri, collection_name, create_collection, dim=128):
        """Connect to Milvus and load or create the target collection.

        Args:
            milvus_uri: Kept for interface compatibility; the connection
                currently always targets a local server (embedded Milvus
                Lite or http://127.0.0.1:19530).
            collection_name: Name of the collection to load/create.
            create_collection: When True, create the collection (and its
                index) if it does not already exist.
            dim: Dimensionality of each embedding vector.
        """
        # Load configuration (metric type, index params, topk) from a .env file.
        import dotenv

        dotenv.load_dotenv(dotenv.find_dotenv())

        if default_server is not None:
            # Embedded Milvus Lite: start it; a failure here is reported but
            # not fatal, since the server may already be running.
            try:
                default_server.start()
            except Exception as exc:
                print(f"Milvus Lite start failed (may already be running): {exc}")
            self.client = MilvusClient(
                uri=f"http://127.0.0.1:{default_server.listen_port}"
            )
        else:
            # Fallback: standard local server (docker-compose or system service).
            self.client = MilvusClient(uri="http://127.0.0.1:19530")

        self.collection_name = collection_name
        self.dim = dim

        if self.client.has_collection(collection_name=self.collection_name):
            self.client.load_collection(collection_name=self.collection_name)
            print("Loaded existing collection.")
        elif create_collection:
            self.create_collection()
            self.create_index()

    def create_collection(self):
        """Create the collection with the multi-vector ColBERT schema.

        No-op if the collection already exists. Each row stores one token
        vector (`vector`), its position in the document (`seq_id`), the
        owning document id (`doc_id`), and an optional payload (`doc`).
        """
        if self.client.has_collection(collection_name=self.collection_name):
            print("Collection already exists.")
            return
        schema = self.client.create_schema(
            auto_id=True,
            enable_dynamic_fields=True,
        )
        schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)
        schema.add_field(
            field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.dim
        )
        schema.add_field(field_name="seq_id", datatype=DataType.INT16)
        schema.add_field(field_name="doc_id", datatype=DataType.INT64)
        schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)
        self.client.create_collection(
            collection_name=self.collection_name, schema=schema
        )

    def create_index(self):
        """Create an HNSW index on the vector field.

        Index parameters come from the environment with fallbacks matching
        the documented defaults: metrictype ("IP"), mnum (HNSW M, 16) and
        efnum (efConstruction, 500). HNSW favors recall at higher memory
        cost; switch to IVF if memory is tight.
        """
        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="vector",
            index_name="vector_index",
            index_type="HNSW",
            metric_type=os.environ.get("metrictype", "IP"),
            params={
                "M": int(os.environ.get("mnum", "16")),
                "efConstruction": int(os.environ.get("efnum", "500")),
            },
        )
        self.client.create_index(
            collection_name=self.collection_name, index_params=index_params, sync=True
        )

    def _rerank_single_doc(self, doc_id, data, collection_name):
        """Late-interaction (MaxSim) rescoring of one candidate document.

        Fetches the token vectors stored under `doc_id` (the `doc_id + 1`
        term mirrors the original candidate-expansion filter), then scores:
        sum over query tokens of the max dot product against doc tokens.

        Returns:
            Tuple of (score, doc_id, collection_name).
        """
        doc_colbert_vecs = self.client.query(
            collection_name=collection_name,
            filter=f"doc_id in [{doc_id}, {doc_id + 1}]",
            output_fields=["seq_id", "vector", "doc"],
            limit=16380,
        )
        doc_vecs = np.vstack([row["vector"] for row in doc_colbert_vecs])
        score = np.dot(data, doc_vecs.T).max(1).sum()
        return (score, doc_id, collection_name)

    def search(self, data, topk):
        """Search every collection and return the top-k reranked hits.

        Args:
            data: Query embedding matrix (one row per query token).
            topk: Number of results to return.

        Returns:
            List of (score, doc_id, collection_name) tuples sorted by
            descending score; doc_id corresponds to the page number.
        """
        collections = self.client.list_collections()
        search_params = {
            "metric_type": os.environ.get("metrictype", "IP"),
            "params": {},
        }

        # Unique (doc_id, collection) candidates gathered across collections.
        doc_collection_pairs = set()
        loaded = []
        try:
            for collection in collections:
                self.client.load_collection(collection_name=collection)
                loaded.append(collection)
                print("collection loaded:" + collection)
                results = self.client.search(
                    collection,
                    data,
                    # Per-collection candidate pool size (default 50).
                    limit=int(os.environ.get("topk", "50")),
                    output_fields=["vector", "seq_id", "doc_id"],
                    search_params=search_params,
                )
                for per_query_hits in results:
                    for hit in per_query_hits:
                        doc_collection_pairs.add((hit["entity"]["doc_id"], collection))

            # Rerank candidates concurrently; the work is I/O-bound (Milvus
            # queries), so a bounded thread pool is sufficient.
            scores = []
            workers = min(32, max(1, len(doc_collection_pairs)))
            with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
                futures = [
                    executor.submit(self._rerank_single_doc, doc_id, data, collection)
                    for doc_id, collection in doc_collection_pairs
                ]
                for future in concurrent.futures.as_completed(futures):
                    scores.append(future.result())
        finally:
            # Release every collection we loaded to free memory. (The
            # original released only the last collection, leaking the rest.)
            for collection in loaded:
                self.client.release_collection(collection_name=collection)

        scores.sort(key=lambda item: item[0], reverse=True)
        return scores[:topk]

    def insert(self, data):
        """Insert one document's token vectors as individual rows.

        Args:
            data: Dict with "colbert_vecs" (sequence of token vectors),
                "doc_id", and "filepath". The file path is stored on the
                first row's `doc` field only; remaining rows get "".
        """
        colbert_vecs = data["colbert_vecs"]
        seq_length = len(colbert_vecs)
        self.client.insert(
            self.collection_name,
            [
                {
                    "vector": colbert_vecs[i],
                    "seq_id": i,
                    "doc_id": data["doc_id"],
                    # Only the first row carries the source file path.
                    "doc": data["filepath"] if i == 0 else "",
                }
                for i in range(seq_length)
            ],
        )

    def get_images_as_doc(self, images_with_vectors):
        """Convert embedded images into insertable document dicts.

        The list index becomes the doc_id (i.e. the page number).
        """
        return [
            {
                "colbert_vecs": image["colbert_vecs"],
                "doc_id": idx,
                "filepath": image["filepath"],
            }
            for idx, image in enumerate(images_with_vectors)
        ]

    def insert_images_data(self, image_data):
        """Insert a batch of embedded images, one document per image."""
        for item in self.get_images_as_doc(image_data):
            self.insert(item)