Spaces:
Paused
Paused
| # milvus.py | |
| from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility | |
| import pandas as pd | |
| import os | |
| import sys | |
| from sentence_transformers import SentenceTransformer | |
| import time | |
# Fallback settings; the __main__ block reads an environment variable for each
# of these and uses the value below only when that variable is unset.
DEFAULT_MILVUS_HOST = "localhost"
DEFAULT_MILVUS_PORT = "19530"
DEFAULT_COLLECTION_NAME = "document_collection"
# Vector width of the collection; adjust to the embedding model in use.
DEFAULT_DIMENSION = 384
# Connection retry policy: number of attempts and pause between them (seconds).
DEFAULT_MAX_RETRIES = 3
DEFAULT_RETRY_DELAY = 5

# Sentence-embedding model, instantiated once when the module is imported.
model = SentenceTransformer("all-MiniLM-L6-v2")
def create_milvus_collection(host, port, collection_name, dimension):
    """
    Make sure the named Milvus collection exists, creating it when absent.

    A new collection gets three fields (auto-generated INT64 primary key
    ``id``, VARCHAR ``path`` limited to 500 characters, and FLOAT_VECTOR
    ``content_vector``) plus an IVF_FLAT / L2 index on the vector field.

    Args:
        host: Accepted but not used by this function.
        port: Accepted but not used by this function.
        collection_name: Name of the collection to look up or create.
        dimension: Dimensionality of the ``content_vector`` field.
    """
    # Guard clause: nothing to build when the collection is already there.
    if utility.has_collection(collection_name):
        print(f"Collection {collection_name} already exists.")
        return

    id_field = FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True)
    path_field = FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=500)
    vector_field = FieldSchema(name="content_vector", dtype=DataType.FLOAT_VECTOR, dim=dimension)

    schema = CollectionSchema([id_field, path_field, vector_field], "Document Vector Store")
    collection = Collection(collection_name, schema, consistency_level="Strong")

    collection.create_index(
        field_name="content_vector",
        index_params={
            "metric_type": "L2",
            "index_type": "IVF_FLAT",
            "params": {"nlist": 1024},
        },
    )
    print(f"Collection {collection_name} created and index built.")
def load_data_to_milvus(host, port, collection_name):
    """
    Load one pickled DataFrame from the 'extraction' directory into Milvus.

    The alphabetically first ``.pkl`` file in ``extraction`` is read. Its
    ``content`` column is embedded with the module-level ``model`` and the
    vectors are inserted, together with the ``path`` column, into the given
    collection, which is then flushed. Any further ``.pkl`` files in the
    directory are ignored.

    Args:
        host: Accepted but not used by this function.
        port: Accepted but not used by this function.
        collection_name: Name of the Milvus collection to insert into.

    Returns:
        None. Returns early, after printing a message, when the directory
        holds no ``.pkl`` file or the DataFrame has no rows.

    Raises:
        FileNotFoundError: If the ``extraction`` directory does not exist.
        KeyError: If the DataFrame lacks a ``content`` or ``path`` column.
    """
    extraction_dir = "extraction"
    # os.listdir yields entries in arbitrary order; sort so the same file is
    # picked on every run and on every platform.
    pkl_files = sorted(f for f in os.listdir(extraction_dir) if f.endswith('.pkl'))
    if not pkl_files:
        print("No .pkl files found in the 'extraction' directory.")
        return

    df_path = os.path.join(extraction_dir, pkl_files[0])
    # NOTE(security): unpickling can execute arbitrary code. Only load files
    # produced by a trusted extraction step.
    df = pd.read_pickle(df_path)

    # Nothing to embed or insert; skip the Milvus round-trip entirely.
    if len(df) == 0:
        print(f"No rows found in {df_path}; nothing to load.")
        return

    contents = df['content'].tolist()
    paths = df['path'].tolist()

    # Encode all texts in one batched call instead of one call per row.
    vectors = model.encode(contents).tolist()

    # Column order is path then vector; no id column is supplied.
    collection = Collection(collection_name)
    collection.insert([paths, vectors])
    collection.flush()
    print(f"Data from {df_path} loaded into Milvus collection {collection_name}.")
def connect_to_milvus(host, port, max_retries, retry_delay):
    """
    Open a Milvus connection, trying up to ``max_retries`` times.

    Args:
        host: Milvus server host.
        port: Milvus server port.
        max_retries: Maximum number of connection attempts.
        retry_delay: Seconds to sleep between failed attempts.

    Returns:
        True as soon as an attempt succeeds, False when every attempt failed
        (or when ``max_retries`` allows no attempt at all).
    """
    attempt = 0
    while attempt < max_retries:
        attempt += 1
        try:
            connections.connect(host=host, port=port)
        except Exception as exc:
            print(f"Error connecting to Milvus: {exc}")
            if attempt >= max_retries:
                # Last attempt failed: report it and give up without sleeping.
                print("Max retries reached. Could not connect to Milvus.")
                break
            print(f"Retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
        else:
            print(f"Successfully connected to Milvus at {host}:{port}")
            return True
    return False
def initialize_milvus(host, port, collection_name, dimension, max_retries, retry_delay):
    """Connect to Milvus, ensure the collection exists, and load the data.

    Args:
        host: Milvus server host.
        port: Milvus server port.
        collection_name: Name of the collection to create and populate.
        dimension: Dimensionality of the vector field for a new collection.
        max_retries: Maximum number of connection attempts.
        retry_delay: Seconds to wait between connection attempts.

    Returns:
        True if successfully connected and initialized, False otherwise.
    """
    if not connect_to_milvus(host, port, max_retries, retry_delay):
        return False  # Connection failed; nothing to clean up.

    succeeded = True
    try:
        create_milvus_collection(host, port, collection_name, dimension)
        load_data_to_milvus(host, port, collection_name)
    except Exception as e:
        print(f"Error during initialization: {e}")
        succeeded = False
    finally:
        # Always release the connection, including when collection creation
        # or data loading failed, so a failed run does not leave it open.
        try:
            connections.disconnect(alias='default')
        except Exception as e:
            # A failed disconnect still counts as a failed initialization.
            print(f"Error during initialization: {e}")
            succeeded = False
    return succeeded
if __name__ == "__main__":
    # Each setting is read from its environment variable when set, otherwise
    # it falls back to the matching module-level default.
    milvus_host = os.getenv('MILVUS_HOST', DEFAULT_MILVUS_HOST)
    milvus_port = os.getenv('MILVUS_PORT', DEFAULT_MILVUS_PORT)
    collection_name = os.getenv('COLLECTION_NAME', DEFAULT_COLLECTION_NAME)

    # Numeric settings arrive as strings from the environment; convert them.
    dimension = int(os.getenv('DIMENSION', DEFAULT_DIMENSION))
    max_retries = int(os.getenv('MAX_RETRIES', DEFAULT_MAX_RETRIES))
    retry_delay = int(os.getenv('RETRY_DELAY', DEFAULT_RETRY_DELAY))

    initialize_milvus(
        host=milvus_host,
        port=milvus_port,
        collection_name=collection_name,
        dimension=dimension,
        max_retries=max_retries,
        retry_delay=retry_delay,
    )