Spaces:
Runtime error
Runtime error
| import os | |
| import shutil | |
| from typing import List | |
| import pandas as pd | |
| from pymilvus import MilvusClient, connections, FieldSchema, CollectionSchema, DataType, Collection | |
| import logging | |
| from backend.classes.vector_database.base_vector_database import VectorDatabaseConfig, VectorDatabase | |
| logger = logging.getLogger(__name__) | |
class MilvusVectorDatabaseConfig(VectorDatabaseConfig):
    """Settings consumed by ``MilvusVectorDatabase``."""

    # Location handed verbatim to ``MilvusClient(...)`` when connecting.
    db_path: str

    # Name of the collection that is created, written to and searched.
    collection_name: str

    # Dimensionality of the ``embedding`` vector field in the collection schema.
    vector_dimensions: int

    # Forwarded to ``create_collection``: when true, a collection with the same
    # name is dropped before the new one is created.
    drop_if_exists: bool = True

    class Config:
        # NOTE(review): looks like pydantic-style model configuration; confirm
        # against the VectorDatabaseConfig base class.
        arbitrary_types_allowed = True
class MilvusVectorDatabase(VectorDatabase):
    """Vector database implementation backed by ``pymilvus.MilvusClient``.

    On construction the instance connects to ``config.db_path`` and creates the
    collection ``config.collection_name`` (optionally dropping an existing one).
    The collection holds an auto-generated ``id``, a ``text`` chunk, its
    ``embedding`` vector and a JSON ``metadata`` document.
    """

    # Metric the vector index is built with. Searches must use the same metric;
    # Milvus rejects a search whose metric differs from the index's.
    _METRIC_TYPE = "COSINE"

    # Characters trimmed from query terms before keyword matching.
    _TERM_PUNCTUATION = ".,;:!?\"'()[]{}"

    def __init__(self, config: MilvusVectorDatabaseConfig):
        """Connect to Milvus and create the collection described by *config*.

        Args:
            config: Connection and collection settings.
        """
        # NOTE(review): assumes the base class stores *config* as ``self.config``;
        # ``connect`` and the other methods read it from there.
        super().__init__(config)
        self.client = self.connect()
        self.create_collection(config.drop_if_exists)

    def connect(self):
        """Open a ``MilvusClient`` for ``config.db_path`` and return it."""
        logger.info(f"\nConnecting to Milvus at {self.config.db_path}...")
        client = MilvusClient(self.config.db_path)
        logger.info("Connected to Milvus.")
        return client

    def _define_schema(self) -> List[FieldSchema]:
        """Return the field definitions of the collection.

        - ``id``: auto-generated INT64 primary key.
        - ``text``: the chunk text (VARCHAR, at most 1024 characters/bytes as
          enforced by Milvus).
        - ``embedding``: dense float vector of ``config.vector_dimensions``.
        - ``metadata``: JSON field for free-form document metadata.
        """
        return [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
            FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=1024),
            FieldSchema(
                name="embedding",
                dtype=DataType.FLOAT_VECTOR,
                dim=self.config.vector_dimensions,
            ),
            FieldSchema(
                name="metadata",
                dtype=DataType.JSON,
                description="Flexible JSON metadata for the document",
            ),
        ]

    def create_collection(self, drop_if_exists: bool = True):
        """Create the collection together with its vector index.

        Args:
            drop_if_exists: If True, drop a collection with the configured name
                before creating the new one. Defaults to True.
        """
        collection_name = self.config.collection_name

        if drop_if_exists:
            logger.info(f"Dropping existing collection '{collection_name}'...")
            self.client.drop_collection(collection_name=collection_name)

        # Vector index on the embedding field; the metric chosen here is the
        # one every later search has to use.
        logger.info(
            "Preparing IVF_FLAT vector index on 'embedding' (metric %s)...",
            self._METRIC_TYPE,
        )
        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="embedding",
            metric_type=self._METRIC_TYPE,
            index_type="IVF_FLAT",
            params={"nlist": 128},
        )

        milvus_schema = CollectionSchema(
            fields=self._define_schema(),
            description="Hybrid search collection for Finance documents",
        )

        logger.info(f"Creating collection '{collection_name}'...")
        self.client.create_collection(
            collection_name=collection_name,
            schema=milvus_schema,
            index_params=index_params,
            dimension=self.config.vector_dimensions,
        )

    def add_texts(self, df: pd.DataFrame, embeddings: list):
        """Insert the rows of *df* together with their embeddings.

        Rows and embeddings are paired by position, so *df* may carry any index
        (for example after filtering) as long as both have the same length.

        Args:
            df: One row per text chunk; every column is inserted as a field of
                the same name.
            embeddings: One embedding per row of *df*, in row order.

        Raises:
            ValueError: If *df* and *embeddings* differ in length.
        """
        if len(df) != len(embeddings):
            raise ValueError(
                f"Got {len(df)} rows but {len(embeddings)} embeddings; "
                "they must correspond one-to-one."
            )
        if len(df) == 0:
            logger.info("No rows to insert.")
            return

        data = [
            {**record, "embedding": vector}
            for record, vector in zip(df.to_dict(orient="records"), embeddings)
        ]
        self.client.insert(collection_name=self.config.collection_name, data=data)

    @staticmethod
    def _hit_field(hit, name: str, default=None):
        """Read an output field from a search hit.

        ``MilvusClient.search`` nests requested output fields under the hit's
        ``"entity"`` key; fall back to the top level for flat hit layouts.
        """
        entity = hit.get("entity") or {}
        if name in entity:
            return entity[name]
        return hit.get(name, default)

    @classmethod
    def _query_terms(cls, query_text) -> set:
        """Split *query_text* into lower-cased terms without edge punctuation."""
        terms = {
            term.strip(cls._TERM_PUNCTUATION)
            for term in (query_text or "").lower().split()
        }
        terms.discard("")
        return terms

    @staticmethod
    def _keyword_score(terms: set, text) -> float:
        """Return the fraction of *terms* that occur in *text* (0.0 to 1.0)."""
        if not terms or not isinstance(text, str):
            return 0.0
        haystack = text.lower()
        return sum(term in haystack for term in terms) / len(terms)

    @staticmethod
    def _as_query_batch(query_embedding):
        """Return *query_embedding* as a batch (list of vectors).

        A single flat vector is wrapped; an input whose first element is itself
        a sequence is taken to be a batch already and returned unchanged.
        """
        if len(query_embedding) > 0 and hasattr(query_embedding[0], "__len__"):
            return query_embedding
        return [query_embedding]

    def hybrid_search(self, query_embedding: list, query_text: str, limit: int = 5,
                      text_weight: float = 0.4, embedding_weight: float = 0.6) -> list:
        """Search by vector similarity and re-rank with a keyword score.

        ``limit * 2`` candidates are fetched by vector similarity, each is
        scored as ``embedding_weight * similarity + text_weight * keyword``,
        and the best ``limit`` are returned.

        Args:
            query_embedding: Single embedding vector to search for.
            query_text: Query text; its terms are matched (case-insensitively)
                against each candidate's ``text``.
            limit: Number of results to return.
            text_weight: Weight of the keyword score (fraction of query terms
                found in the candidate text).
            embedding_weight: Weight of the vector similarity reported by Milvus.

        Returns:
            Up to ``limit`` dicts with keys ``id``, ``distance``, ``text``,
            ``metadata`` and ``score``, ordered by descending ``score``.
        """
        search_results = self.client.search(
            collection_name=self.config.collection_name,
            data=[query_embedding],
            anns_field="embedding",
            # Must name the metric the index was built with.
            search_params={
                "metric_type": self._METRIC_TYPE,
                "params": {"nprobe": 10},
            },
            limit=limit * 2,  # over-fetch so keyword re-ranking has candidates
            output_fields=["text", "metadata"],
        )
        if not search_results or not search_results[0]:
            return []

        terms = self._query_terms(query_text)
        ranked = []
        for hit in search_results[0]:
            text = self._hit_field(hit, "text", "N/A")
            # With the COSINE metric Milvus reports a similarity: larger means
            # closer, so it can be combined directly with the keyword score.
            similarity = hit["distance"]
            ranked.append({
                "id": hit["id"],
                "distance": similarity,
                "text": text,
                "metadata": self._hit_field(hit, "metadata", {}),
                "score": (
                    embedding_weight * similarity
                    + text_weight * self._keyword_score(terms, text)
                ),
            })

        # Stable sort: ties keep the vector-similarity order from Milvus.
        ranked.sort(key=lambda item: item["score"], reverse=True)
        return ranked[:limit]

    def search_similar_texts(self, query_embedding: list, limit: int = 5):
        """Return the texts whose embeddings are closest to *query_embedding*.

        Args:
            query_embedding: A single embedding vector, or a batch (list of
                vectors); for a batch only the first query's hits are returned.
            limit: Number of results to return.

        Returns:
            List of ``{"text": ..., "distance": ...}`` dicts; empty if Milvus
            returned no hits.
        """
        search_results = self.client.search(
            collection_name=self.config.collection_name,
            data=self._as_query_batch(query_embedding),
            anns_field="embedding",
            limit=limit,
            output_fields=["text"],
        )
        if not search_results:
            return []
        return [
            {"text": self._hit_field(hit, "text"), "distance": hit["distance"]}
            for hit in search_results[0]
        ]

    def drop_collection(self):
        """Drop the collection and remove local Milvus Lite data, if any.

        The collection is dropped through the client first. If
        ``config.db_path`` exists on the local filesystem, the client is then
        closed and the path is deleted, whether it is a single database file
        or a directory. After local data has been removed this instance must
        not be used again.
        """
        self.client.drop_collection(collection_name=self.config.collection_name)

        db_path = self.config.db_path
        if not os.path.exists(db_path):
            logger.info(f"Local data directory '{db_path}' not found, nothing to clean.")
            return

        logger.info(f"Removing local Milvus Lite data directory: {db_path}...")
        # Release the client's handle on the local data before deleting it.
        self.client.close()
        if os.path.isdir(db_path) and not os.path.islink(db_path):
            shutil.rmtree(db_path)
        else:
            os.remove(db_path)
        logger.info("Local data removed.")