from pymilvus import MilvusClient, DataType
import numpy as np
import concurrent.futures
from pymilvus import Collection
import os
| |
|
class MilvusManager:
    """Manage a Milvus collection of ColBERT-style multi-vector embeddings.

    Supports collection/index creation, multi-vector (MaxSim) search with
    concurrent reranking, and insertion of per-document embedding batches.
    """

    def __init__(self, milvus_uri, collection_name, create_collection, dim=128):
        """Connect to Milvus and optionally create the collection and index.

        Args:
            milvus_uri: URI of the Milvus server (e.g. "http://localhost:19530").
            collection_name: Name of the collection this manager operates on.
            create_collection: If True, create the collection (and its index)
                when it does not already exist.
            dim: Dimensionality of the stored float vectors.
        """
        # Load .env so index/search settings (metrictype, mnum, efnum, topk)
        # are available through os.environ in the other methods.
        import dotenv

        dotenv.load_dotenv(dotenv.find_dotenv())

        # BUG FIX: the caller-supplied URI was previously ignored in favor of
        # a hard-coded localhost address; honor the milvus_uri argument.
        self.client = MilvusClient(uri=milvus_uri, token="root:Milvus")
        self.collection_name = collection_name
        self.dim = dim

        if self.client.has_collection(collection_name=self.collection_name):
            self.client.load_collection(collection_name=self.collection_name)
            print("Loaded existing collection.")
        elif create_collection:
            self.create_collection()
            self.create_index()

    def create_collection(self):
        """Create the multi-vector collection; no-op if it already exists."""
        if self.client.has_collection(collection_name=self.collection_name):
            print("Collection already exists.")
            return

        schema = self.client.create_schema(
            auto_id=True,
            enable_dynamic_fields=True,
        )
        schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)
        schema.add_field(
            field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.dim
        )
        # seq_id: position of this vector within its document's embedding sequence.
        schema.add_field(field_name="seq_id", datatype=DataType.INT16)
        schema.add_field(field_name="doc_id", datatype=DataType.INT64)
        # doc: source filepath, stored only on the first row of each document
        # (see insert()).
        schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)

        self.client.create_collection(
            collection_name=self.collection_name, schema=schema
        )

    def create_index(self):
        """Build an HNSW index on the vector field, configured via environment.

        Reads "metrictype", "mnum" and "efnum" from os.environ; raises
        KeyError if they are unset (a .env file is loaded in __init__).
        """
        index_params = self.client.prepare_index_params()

        index_params.add_index(
            field_name="vector",
            index_name="vector_index",
            index_type="HNSW",
            metric_type=os.environ["metrictype"],
            params={
                "M": int(os.environ["mnum"]),
                "efConstruction": int(os.environ["efnum"]),
            },
        )

        self.client.create_index(
            collection_name=self.collection_name, index_params=index_params, sync=True
        )

    def search(self, data, topk):
        """Search all collections for candidates, then rerank by MaxSim score.

        Args:
            data: Query embedding matrix, presumably shaped
                (num_query_tokens, dim) — confirm against the embedder.
            topk: Maximum number of results to return.

        Returns:
            List of (score, doc_id, collection_name) tuples, best first,
            at most ``topk`` long.
        """
        collections = self.client.list_collections()
        search_params = {"metric_type": os.environ["metrictype"], "params": {}}

        # First pass: ANN search in every collection to gather candidate docs.
        doc_collection_pairs = set()
        for collection in collections:
            self.client.load_collection(collection_name=collection)
            print("collection loaded:" + collection)
            results = self.client.search(
                collection,
                data,
                limit=int(os.environ["topk"]),
                output_fields=["vector", "seq_id", "doc_id"],
                search_params=search_params,
            )
            for per_query_hits in results:
                for hit in per_query_hits:
                    doc_collection_pairs.add((hit["entity"]["doc_id"], collection))

        def rerank_single_doc(doc_id, data, client, collection_name):
            """Compute the MaxSim score of one candidate document."""
            # NOTE(review): the filter also pulls in doc_id + 1; kept to
            # preserve existing behavior, but it looks suspicious — confirm
            # whether only doc_id was intended.
            doc_colbert_vecs = client.query(
                collection_name=collection_name,
                filter=f"doc_id in [{doc_id}, {doc_id + 1}]",
                output_fields=["seq_id", "vector", "doc"],
                limit=16380,
            )
            doc_vecs = np.vstack([row["vector"] for row in doc_colbert_vecs])
            # MaxSim: best-matching doc vector per query token, summed over
            # all query tokens.
            score = np.dot(data, doc_vecs.T).max(1).sum()
            return (score, doc_id, collection_name)

        # Second pass: rerank candidates concurrently (queries are I/O-bound,
        # so threads overlap the network waits).
        scores = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=300) as executor:
            futures = {
                executor.submit(
                    rerank_single_doc, doc_id, data, self.client, collection
                ): (doc_id, collection)
                for doc_id, collection in doc_collection_pairs
            }
            for future in concurrent.futures.as_completed(futures):
                scores.append(future.result())

        scores.sort(key=lambda x: x[0], reverse=True)

        # BUG FIX: previously only the last-iterated collection was released,
        # leaving every other loaded collection resident in memory.
        for collection in collections:
            self.client.release_collection(collection_name=collection)

        # Slicing already returns the whole list when len(scores) < topk.
        return scores[:topk]

    def insert(self, data):
        """Insert one document's ColBERT vectors as individual rows.

        Args:
            data: dict with keys "colbert_vecs" (sequence of vectors),
                "doc_id" (int), and "filepath" (str). The filepath is stored
                in the "doc" field of the first row only; remaining rows get
                an empty string.
        """
        colbert_vecs = data["colbert_vecs"]
        seq_length = len(colbert_vecs)

        self.client.insert(
            self.collection_name,
            [
                {
                    "vector": colbert_vecs[i],
                    "seq_id": i,
                    "doc_id": data["doc_id"],
                    # Filepath stored once, on the document's first vector.
                    "doc": data["filepath"] if i == 0 else "",
                }
                for i in range(seq_length)
            ],
        )

    def get_images_as_doc(self, images_with_vectors):
        """Convert embedded images into insertable docs, assigning doc_ids by index."""
        return [
            {
                "colbert_vecs": image["colbert_vecs"],
                "doc_id": idx,
                "filepath": image["filepath"],
            }
            for idx, image in enumerate(images_with_vectors)
        ]

    def insert_images_data(self, image_data):
        """Insert a batch of embedded images, one document per image."""
        for item in self.get_images_as_doc(image_data):
            self.insert(item)
|