import logging
import os
import shutil
from typing import List

import pandas as pd
from pymilvus import MilvusClient, connections, FieldSchema, CollectionSchema, DataType, Collection

from backend.classes.vector_database.base_vector_database import VectorDatabaseConfig, VectorDatabase

logger = logging.getLogger(__name__)


class MilvusVectorDatabaseConfig(VectorDatabaseConfig):
    """Configuration for a Milvus vector database.

    Attributes:
        db_path: Milvus Lite database file path (or server URI) passed to MilvusClient.
        collection_name: Name of the collection to create and query.
        vector_dimensions: Dimensionality of the dense embedding vectors.
        drop_if_exists: If True, drop any pre-existing collection on startup.
    """

    db_path: str
    collection_name: str
    vector_dimensions: int
    drop_if_exists: bool = True

    class Config:
        arbitrary_types_allowed = True


class MilvusVectorDatabase(VectorDatabase):
    """Vector database implementation backed by Milvus (via MilvusClient)."""

    def __init__(self, config: MilvusVectorDatabaseConfig):
        """Connect to Milvus and (re)create the configured collection."""
        super().__init__(config)
        self.client = self.connect()
        self.create_collection(config.drop_if_exists)

    def connect(self) -> MilvusClient:
        """Open and return a MilvusClient for the configured db_path."""
        logger.info("Connecting to Milvus at %s...", self.config.db_path)
        client = MilvusClient(self.config.db_path)
        logger.info("Connected to Milvus.")
        return client

    def _define_schema(self) -> List[FieldSchema]:
        """Define the Milvus collection schema.

        Fields:
            - ``id``: auto-generated INT64 primary key for unique chunk identification.
            - ``text``: VARCHAR chunk text (max 1024 chars), usable for keyword filtering.
            - ``embedding``: FLOAT_VECTOR of ``config.vector_dimensions`` for similarity search.
            - ``metadata``: JSON field for flexible, filterable document metadata.
        """
        return [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
            FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=1024),
            FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=self.config.vector_dimensions),
            FieldSchema(name="metadata", dtype=DataType.JSON, description="Flexible JSON metadata for the document"),
        ]

    def create_collection(self, drop_if_exists: bool = True):
        """Create the Milvus collection with its schema and vector index.

        Args:
            drop_if_exists: If True, drop the collection first when it already
                exists. Defaults to True.
        """
        # Guard with has_collection so we don't attempt to drop a collection
        # that was never created.
        if drop_if_exists and self.client.has_collection(collection_name=self.config.collection_name):
            logger.info("Dropping existing collection '%s'...", self.config.collection_name)
            self.client.drop_collection(collection_name=self.config.collection_name)

        # Vector index on 'embedding' (IVF_FLAT with COSINE similarity) so that
        # searches must also use the COSINE metric.
        logger.info("Preparing vector index on 'embedding'...")
        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="embedding",
            metric_type="COSINE",
            index_type="IVF_FLAT",
            params={"nlist": 128},
        )

        milvus_schema = CollectionSchema(
            fields=self._define_schema(),
            description="Hybrid search collection for Finance documents",
        )

        logger.info("Creating collection '%s'...", self.config.collection_name)
        # Note: 'dimension' is not passed here — the embedding dim is already
        # fixed by the FLOAT_VECTOR field in the schema, and create_collection
        # ignores 'dimension' when an explicit schema is supplied.
        self.client.create_collection(
            collection_name=self.config.collection_name,
            schema=milvus_schema,
            index_params=index_params,
        )

    def add_texts(self, df: pd.DataFrame, embeddings: list):
        """Insert texts and their embeddings into the collection.

        Args:
            df: DataFrame whose columns match the collection schema
                (must include 'text'; 'metadata' if present is inserted as-is).
            embeddings: List of embedding vectors, positionally aligned with
                the rows of ``df``.
        """
        # Enumerate positionally: df may carry a non-RangeIndex (e.g. after
        # filtering), in which case the label index would misalign embeddings.
        data = []
        for pos, (_, row) in enumerate(df.iterrows()):
            record = row.to_dict()
            record["embedding"] = embeddings[pos]
            data.append(record)
        self.client.insert(collection_name=self.config.collection_name, data=data)

    def hybrid_search(self, query_embedding: list, query_text: str, limit: int = 5,
                      text_weight: float = 0.4, embedding_weight: float = 0.6) -> list:
        """Search by vector similarity (hybrid text fusion not yet implemented).

        Args:
            query_embedding: Embedding vector for similarity search.
            query_text: Text query for keyword search. TODO: currently unused —
                no text-based scoring or weighted fusion is performed yet.
            limit: Number of final results desired.
            text_weight: Weight for the text score (currently unused).
            embedding_weight: Weight for the embedding score (currently unused).

        Returns:
            List of result dicts with 'id', 'distance', 'text' and 'metadata'.
        """
        output_fields = ["text", "metadata"]

        # MilvusClient.search takes 'search_params' (not 'param'); the metric
        # must match the COSINE index built in create_collection.
        search_results = self.client.search(
            collection_name=self.config.collection_name,
            data=[query_embedding],
            anns_field="embedding",
            search_params={"metric_type": "COSINE", "params": {"nprobe": 10}},
            limit=limit * 2,  # extra candidates, reserved for future text-score fusion
            output_fields=output_fields,
        )

        formatted_results = []
        if search_results and search_results[0]:
            for hit in search_results[0]:
                # Requested output fields are nested under hit['entity'],
                # not at the top level of the hit dict.
                entity = hit.get("entity", {})
                result = {
                    "id": hit["id"],
                    "distance": hit["distance"],
                    "text": entity.get("text", "N/A"),
                    "metadata": entity.get("metadata", {}),
                }
                for field in output_fields:
                    if field not in result:  # avoid overwriting 'text'/'metadata'
                        result[field] = entity.get(field)
                formatted_results.append(result)
        return formatted_results

    def search_similar_texts(self, query_embedding: list, limit: int = 5):
        """Search for texts similar to the given embedding.

        Args:
            query_embedding: A single embedding vector to search for.
            limit: Number of results to return.

        Returns:
            List of dicts with 'text' and 'distance' keys.
        """
        output_fields = ["text"]
        # data must be a LIST of query vectors — wrap the single embedding.
        search_results = self.client.search(
            collection_name=self.config.collection_name,
            data=[query_embedding],
            anns_field="embedding",
            search_params={"metric_type": "COSINE", "params": {"nprobe": 10}},
            limit=limit,
            output_fields=output_fields,
        )
        return [
            {
                "text": hit.get("entity", {}).get("text"),
                "distance": hit["distance"],
            }
            for hit in search_results[0]
        ]

    def drop_collection(self):
        """Remove the local Milvus Lite data at db_path (file or directory)."""
        path = self.config.db_path
        # Milvus Lite typically stores data in a single .db file; rmtree would
        # raise NotADirectoryError on it, so handle both cases.
        if os.path.isdir(path):
            logger.info("Removing local Milvus Lite data directory: %s...", path)
            shutil.rmtree(path)
            logger.info("Local data removed.")
        elif os.path.isfile(path):
            logger.info("Removing local Milvus Lite data file: %s...", path)
            os.remove(path)
            logger.info("Local data removed.")
        else:
            logger.info(f"Local data directory '{path}' not found, nothing to clean.")