Demos/backend/classes/vector_database/milvus_vector_database.py
nikhile-galileo's picture
Added G2.0 changes
753e3c5
raw
history blame
8.57 kB
import os
import shutil
from typing import List
import pandas as pd
from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType
import logging
from backend.classes.vector_database.base_vector_database import VectorDatabaseConfig, VectorDatabase
logger = logging.getLogger(__name__)
class MilvusVectorDatabaseConfig(VectorDatabaseConfig):
    """Configuration for Milvus vector database.

    NOTE(review): presumably a pydantic model via VectorDatabaseConfig — the
    inner `class Config` below is pydantic-v1 style; confirm the pydantic
    version in use (v2 would expect `model_config = ConfigDict(...)`).
    """
    # Path to the Milvus Lite database (passed straight to MilvusClient in
    # MilvusVectorDatabase.connect()).
    db_path: str
    # Name of the collection created / dropped / queried by this database.
    collection_name: str
    # Dimensionality of the dense embedding vectors stored in the collection.
    vector_dimensions: int
    # When True, an existing collection with the same name is dropped on init.
    drop_if_exists: bool = True

    class Config:
        # Allow non-pydantic field types on subclasses / the base config.
        arbitrary_types_allowed = True
class MilvusVectorDatabase(VectorDatabase):
    """Implementation of vector database using Milvus (Milvus Lite, local file-backed)."""

    def __init__(self, config: MilvusVectorDatabaseConfig):
        """Connect to Milvus and (re)create the configured collection.

        Args:
            config: Connection path, collection name, vector dimensions and
                whether to drop a pre-existing collection.
        """
        super().__init__(config)
        self.client = self.connect()
        self.create_collection(config.drop_if_exists)

    def connect(self):
        """Open and return a MilvusClient bound to the configured local db path."""
        logger.info("\nConnecting to Milvus at %s...", self.config.db_path)
        client = MilvusClient(self.config.db_path)
        logger.info("Connected to Milvus.")
        return client

    def _define_schema(self) -> List[FieldSchema]:
        """
        Define the Milvus collection schema for hybrid search.

        Fields:
        - `id`: Auto-generated primary key for unique chunk identification.
        - `text`: The chunked text (VARCHAR, max 1024 chars), usable for
          keyword filtering with `LIKE` or equality.
        - `embedding`: Dense float vector for similarity search, sized from
          `config.vector_dimensions`.
        - `metadata`: JSON field for flexible, filterable document metadata.
        """
        return [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
            FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=1024),
            FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=self.config.vector_dimensions),
            FieldSchema(name="metadata", dtype=DataType.JSON, description="Flexible JSON metadata for the document"),
        ]

    def create_collection(self, drop_if_exists: bool = True):
        """
        Create the Milvus collection with the defined schema and a vector index.

        Args:
            drop_if_exists (bool): If True, drops the collection if it already
                exists before creating a new one. Defaults to True.
        """
        # Only attempt the drop when the collection actually exists; dropping
        # a missing collection is wasted work (and noisy in the logs).
        if drop_if_exists and self.client.has_collection(collection_name=self.config.collection_name):
            logger.info("Dropping existing collection '%s'...", self.config.collection_name)
            self.client.drop_collection(collection_name=self.config.collection_name)

        # Vector index on 'embedding' — COSINE metric; searches against this
        # collection must use the same metric type.
        logger.info("Preparing IVF_FLAT/COSINE index on 'embedding'...")
        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="embedding",
            metric_type="COSINE",
            index_type="IVF_FLAT",
            params={"nlist": 128},
        )

        milvus_schema = CollectionSchema(
            fields=self._define_schema(),
            description="Hybrid search collection for Finance documents",
        )
        logger.info("Creating collection '%s'...", self.config.collection_name)
        # NOTE: `dimension=` is redundant (and ignored) when an explicit schema
        # is supplied — the FLOAT_VECTOR field already carries the dim.
        self.client.create_collection(
            collection_name=self.config.collection_name,
            schema=milvus_schema,
            index_params=index_params,
        )

    def add_texts(self, df: pd.DataFrame, embeddings: list):
        """
        Add texts and their embeddings to the collection.

        Args:
            df: DataFrame whose columns match the collection schema
                (at least 'text' and 'metadata').
            embeddings: List of embeddings, positionally aligned with df rows.

        Raises:
            ValueError: If df and embeddings have different lengths.
        """
        if len(df) != len(embeddings):
            raise ValueError(
                f"Row/embedding count mismatch: {len(df)} rows vs {len(embeddings)} embeddings"
            )
        # Pair rows with embeddings by POSITION, not by DataFrame label index —
        # `embeddings[row_label]` breaks for any non-default index.
        data = []
        for pos, (_, row) in enumerate(df.iterrows()):
            record = row.to_dict()
            record["embedding"] = embeddings[pos]
            data.append(record)

        self.client.insert(collection_name=self.config.collection_name, data=data)

    def hybrid_search(self, query_embedding: list, query_text: str, limit: int = 5,
                      text_weight: float = 0.4, embedding_weight: float = 0.6) -> list:
        """
        Search the collection by vector similarity.

        NOTE: Despite the name, only the dense-vector search is currently
        implemented — `query_text`, `text_weight` and `embedding_weight` are
        accepted for interface compatibility but unused.

        Args:
            query_embedding: Embedding vector for similarity search.
            query_text: Text query (currently unused).
            limit: Number of results to return.
            text_weight: Weight for text-based score (currently unused).
            embedding_weight: Weight for embedding score (currently unused).

        Returns:
            List of dicts with 'id', 'distance', 'text' and 'metadata'.
        """
        output_fields = ["text", "metadata"]
        # Metric must match the index (COSINE); a mismatched metric here is
        # rejected or silently ignored depending on the Milvus version.
        search_results = self.client.search(
            collection_name=self.config.collection_name,
            data=[query_embedding],
            anns_field="embedding",
            param={"metric_type": "COSINE", "params": {"nprobe": 10}},
            limit=limit * 2,  # extra candidates for a future text-score merge
            output_fields=output_fields,
        )

        formatted_results = []
        if search_results and search_results[0]:
            for hit in search_results[0]:
                # MilvusClient nests requested output fields under 'entity';
                # fall back to top-level keys for older result shapes.
                entity = hit.get('entity') or {}
                result = {
                    "id": hit['id'],
                    "distance": hit['distance'],
                    "text": entity.get('text', hit.get('text', 'N/A')),
                    "metadata": entity.get('metadata', hit.get('metadata', {})),
                }
                for field in output_fields:
                    if field not in result:  # don't clobber fields already set
                        result[field] = entity.get(field, hit.get(field))
                formatted_results.append(result)
        return formatted_results

    def search_similar_texts(self, query_embedding: list, limit: int = 5):
        """
        Search for similar texts based on embeddings.

        Args:
            query_embedding: Embedding vector to search for.
            limit: Number of results to return.

        Returns:
            List of dicts with 'text' and 'distance' keys.
        """
        output_fields = ["text"]
        # The API expects a BATCH of query vectors — wrap the single vector.
        search_results = self.client.search(
            collection_name=self.config.collection_name,
            data=[query_embedding],
            anns_field="embedding",
            limit=limit,
            output_fields=output_fields,
        )
        results = []
        for hit in search_results[0]:
            entity = hit.get('entity') or {}
            results.append({
                "text": entity.get("text", hit.get("text")),
                "distance": hit["distance"],
            })
        return results

    def drop_collection(self):
        """Delete the local Milvus Lite data backing this database.

        Milvus Lite stores its data in a single `.db` FILE at `db_path`;
        `shutil.rmtree` on a file raises NotADirectoryError, so handle both
        file and directory layouts.
        """
        if os.path.isdir(self.config.db_path):
            logger.info("Removing local Milvus Lite data directory: %s...", self.config.db_path)
            shutil.rmtree(self.config.db_path)
            logger.info("Local data removed.")
        elif os.path.exists(self.config.db_path):
            logger.info("Removing local Milvus Lite data file: %s...", self.config.db_path)
            os.remove(self.config.db_path)
            logger.info("Local data removed.")
        else:
            logger.info("Local data path '%s' not found, nothing to clean.", self.config.db_path)