File size: 8,590 Bytes
e68d535
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import os
import shutil
from typing import List

import pandas as pd
from pymilvus import MilvusClient, connections, FieldSchema, CollectionSchema, DataType, Collection
import logging

from backend.classes.vector_database.base_vector_database import VectorDatabaseConfig, VectorDatabase

logger = logging.getLogger(__name__)


class MilvusVectorDatabaseConfig(VectorDatabaseConfig):
    """Configuration for Milvus vector database.

    Extends ``VectorDatabaseConfig`` with the settings that
    ``MilvusVectorDatabase`` reads: where to connect, which collection to
    use, the embedding width, and whether to recreate the collection when
    the database object is constructed.
    """
    # Passed unchanged to ``MilvusClient(...)``; ``drop_collection`` also
    # treats it as a local filesystem path.
    db_path: str
    # Name of the collection that every create/insert/search call targets.
    collection_name: str
    # Used as ``dim`` of the ``embedding`` FLOAT_VECTOR field.
    vector_dimensions: int
    # Forwarded to ``create_collection`` from ``MilvusVectorDatabase.__init__``.
    drop_if_exists: bool = True

    class Config:
        # NOTE(review): looks like a pydantic-style model config inherited
        # from the base class convention — confirm against VectorDatabaseConfig.
        arbitrary_types_allowed = True


class MilvusVectorDatabase(VectorDatabase):
    """Implementation of vector database using Milvus."""
    def __init__(self, config: MilvusVectorDatabaseConfig):
        """Connect to Milvus and prepare the configured collection.

        Args:
            config: Connection and collection settings; its
                ``drop_if_exists`` flag is forwarded to
                ``create_collection``.
        """
        super().__init__(config)

        # Open the client first: create_collection() issues its calls through it.
        self.client = self.connect()

        recreate = config.drop_if_exists
        self.create_collection(recreate)

    def connect(self):
        """Open a ``MilvusClient`` for ``self.config.db_path`` and return it.

        Returns:
            The newly constructed client. It is not stored on ``self`` here;
            the caller decides where to keep it.
        """
        logger.info(f"\nConnecting to Milvus at {self.config.db_path}...")
        milvus_client = MilvusClient(self.config.db_path)
        logger.info("Connected to Milvus.")
        return milvus_client

    def _define_schema(self) -> List[FieldSchema]:
        """
        Build the field list for the Milvus collection.

        - `id`: INT64 primary key, generated automatically (`auto_id=True`).
        - `text`: VARCHAR holding the chunk text (`max_length=1024`).
        - `embedding`: FLOAT_VECTOR with `config.vector_dimensions` dimensions.
        - `metadata`: JSON field for free-form document metadata.

        Returns:
            The four field definitions, in the order listed above.
        """
        primary_key = FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True)
        chunk_text = FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=1024)
        vector = FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=self.config.vector_dimensions)
        extra = FieldSchema(name="metadata", dtype=DataType.JSON, description="Flexible JSON metadata for the document")
        return [primary_key, chunk_text, vector, extra]

    def create_collection(self, drop_if_exists: bool = True):
        """
        Ensure the configured collection exists, optionally rebuilding it.

        The collection is created from ``_define_schema()`` with an
        IVF_FLAT index (COSINE metric, ``nlist=128``) on the ``embedding``
        field.

        Args:
            drop_if_exists (bool): If True, an existing collection with the
                configured name is dropped and a fresh one is created. If
                False, an existing collection is reused untouched and a new
                one is created only when none exists. Defaults to True.
        """
        name = self.config.collection_name
        already_exists = self.client.has_collection(collection_name=name)

        if already_exists and not drop_if_exists:
            # Keep the stored data; there is nothing to (re)build.
            logger.info(f"Reusing existing collection '{name}'.")
            return

        if already_exists:
            logger.info(f"Dropping existing collection '{name}'...")
            self.client.drop_collection(collection_name=name)

        # Vector index used by the similarity searches on 'embedding'.
        logger.info("Preparing IVF_FLAT/COSINE vector index on 'embedding'...")
        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="embedding",
            metric_type="COSINE",
            index_type="IVF_FLAT",
            params={"nlist": 128},
        )

        milvus_schema = CollectionSchema(
            fields=self._define_schema(),
            description="Hybrid search collection for Finance documents",
        )

        logger.info(f"Creating collection '{name}'...")
        self.client.create_collection(
            collection_name=name,
            schema=milvus_schema,
            index_params=index_params,
            # NOTE(review): likely redundant when an explicit schema is
            # supplied — confirm before removing.
            dimension=self.config.vector_dimensions,
        )


    def add_texts(self, df: pd.DataFrame, embeddings: list):
        """
        Insert one record per DataFrame row, paired with its embedding.

        Rows and embeddings are matched by position (first row with
        ``embeddings[0]``, and so on), so ``df`` may carry any index, for
        example after filtering. Each record holds the row's columns plus an
        ``embedding`` key. ``df`` itself is not modified.

        Args:
            df: DataFrame whose columns become the record fields.
            embeddings: One embedding per row of ``df``, in row order.

        Raises:
            ValueError: If ``df`` and ``embeddings`` differ in length.
        """
        if len(df) != len(embeddings):
            raise ValueError(
                f"Got {len(df)} rows but {len(embeddings)} embeddings; "
                "they must match one-to-one."
            )

        # Build plain dicts first and attach the vector afterwards, so no
        # list is ever assigned into a pandas Series.
        records = [
            {**row.to_dict(), "embedding": embedding}
            for (_, row), embedding in zip(df.iterrows(), embeddings)
        ]

        if not records:
            # Nothing to insert; skip the client call for an empty batch.
            return

        self.client.insert(collection_name=self.config.collection_name, data=records)

    def hybrid_search(self, query_embedding: list, query_text: str, limit: int = 5,
                     text_weight: float = 0.4, embedding_weight: float = 0.6) -> list:
        """
        Perform hybrid search combining text-based and vector similarity search.

        ``limit * 2`` candidates are fetched by vector similarity, then
        re-ranked with a combined score and cut down to ``limit``::

            score = embedding_weight * distance + text_weight * text_score

        ``text_score`` is the fraction of whitespace-separated terms of
        ``query_text`` that occur (case-insensitively, as substrings) in the
        candidate's text; it is 0.0 when ``query_text`` is empty.

        Args:
            query_embedding: Embedding vector for similarity search
            query_text: Text query for text-based search
            limit: Number of results to return
            text_weight: Weight for text-based search score
            embedding_weight: Weight for embedding similarity score

        Returns:
            List of at most ``limit`` dicts with keys ``id``, ``distance``,
            ``text``, ``metadata`` and ``score``, best score first.
        """
        output_fields = ["text", "metadata"]

        # The metric must match the index built in create_collection (COSINE),
        # and MilvusClient.search takes it through `search_params`.
        search_results = self.client.search(
            collection_name=self.config.collection_name,
            data=[query_embedding],
            anns_field="embedding",
            search_params={"metric_type": "COSINE", "params": {"nprobe": 10}},
            limit=limit * 2,  # over-fetch so text re-ranking has candidates to promote
            output_fields=output_fields,
        )
        if not search_results or not search_results[0]:
            return []

        query_terms = (query_text or "").lower().split()

        formatted_results = []
        for hit in search_results[0]:
            # Output fields may sit at the top level of the hit or be nested
            # under "entity"; check both.
            entity = hit.get("entity") or {}
            text = hit.get("text", entity.get("text"))
            metadata = hit.get("metadata", entity.get("metadata", {}))

            haystack = text.lower() if isinstance(text, str) else ""
            if query_terms:
                matched = sum(term in haystack for term in query_terms)
                text_score = matched / len(query_terms)
            else:
                text_score = 0.0

            distance = hit["distance"]
            formatted_results.append({
                "id": hit["id"],
                "distance": distance,
                "text": text if text is not None else "N/A",
                "metadata": metadata,
                "score": embedding_weight * distance + text_weight * text_score,
            })

        # Stable sort: candidates with equal scores keep the order Milvus returned.
        formatted_results.sort(key=lambda item: item["score"], reverse=True)
        return formatted_results[:limit]

    def search_similar_texts(self, query_embedding: list, limit: int = 5):
        """
        Search for similar texts based on embeddings.

        Args:
            query_embedding: Embedding vector to search for. A batch (list
                of vectors) is also accepted; only the hits of the first
                vector are returned.
            limit: Number of results to return

        Returns:
            List of dicts with the keys ``text`` and ``distance``; empty
            when there is no query vector or the search yields nothing.
        """
        if len(query_embedding) == 0:
            return []

        # The client expects a batch of vectors. Wrap a single flat vector,
        # but pass an already-batched input (first element is itself a
        # sequence) through unchanged.
        already_batched = hasattr(query_embedding[0], "__len__")
        queries = query_embedding if already_batched else [query_embedding]

        search_results = self.client.search(
            collection_name=self.config.collection_name,
            data=queries,
            anns_field="embedding",
            limit=limit,
            output_fields=["text"],
        )
        if not search_results:
            return []

        similar = []
        for hit in search_results[0]:
            # 'text' may sit at the top level of the hit or under "entity".
            entity = hit.get("entity") or {}
            similar.append({
                "text": hit.get("text", entity.get("text")),
                "distance": hit["distance"],
            })
        return similar
        
    def drop_collection(self):
        """Delete the local Milvus Lite storage at ``config.db_path``.

        Both on-disk layouts are handled: a data directory is removed
        recursively, while a single database file (or a symlink) is
        unlinked. Nothing is removed when the path does not exist.

        NOTE(review): the collection is not dropped through the client
        here; only on-disk data is removed — confirm this is the intended
        contract.
        """
        db_path = self.config.db_path

        if not os.path.exists(db_path):
            logger.info(f"Local data directory '{db_path}' not found, nothing to clean.")
            return

        logger.info(f"Removing local Milvus Lite data directory: {db_path}...")
        if os.path.isdir(db_path) and not os.path.islink(db_path):
            shutil.rmtree(db_path)
        else:
            # shutil.rmtree() refuses regular files and symlinks, so unlink
            # those directly.
            os.remove(db_path)
        logger.info("Local data removed.")