import argparse
from typing import Optional, Sequence

from dataloaders.langchain import FinanceBenchDataloader
from langchain_huggingface import HuggingFaceEmbeddings
from pymilvus import CollectionSchema, DataType, FieldSchema
from rag_pipelines.embeddings import SparseEmbeddingsMilvus as SparseEmbeddings
from rag_pipelines.unstructured import UnstructuredChunker, UnstructuredDocumentLoader
from rag_pipelines.utils import dict_type
from rag_pipelines.vectordb import MilvusVectorDB


def parse_arguments(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse command-line arguments for the FinanceBench indexing pipeline.

    Args:
        argv: Argument strings to parse. When ``None`` (the default), argparse
            falls back to ``sys.argv[1:]``, so existing zero-argument callers
            behave exactly as before. Passing an explicit list allows the
            parser to be driven programmatically.

    Returns:
        argparse.Namespace: Parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(
        description="Run the FinanceBench pipeline to load, process, chunk, embed, and index documents."
    )

    # FinanceBench dataset parameters
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="PatronusAI/financebench",
        help="HuggingFace dataset name.",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="train",
        help="Dataset split to use (e.g., 'train').",
    )

    # PDF directory for unstructured document loader
    parser.add_argument(
        "--pdf_dir",
        type=str,
        default="pdfs/",
        help="Directory path containing PDF files.",
    )

    # UnstructuredDocumentLoader parameters
    parser.add_argument(
        "--strategy",
        type=str,
        default="fast",
        help="Processing strategy for the unstructured document loader.",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="elements",
        help="Extraction mode for the unstructured document loader.",
    )

    # Milvus connection parameters.
    # NOTE: neither the URI nor the token has a default, so both are None
    # unless supplied on the command line.
    parser.add_argument(
        "--milvus_uri",
        type=str,
        help="URI for the Milvus server.",
    )
    parser.add_argument(
        "--milvus_token",
        type=str,
        help="Authentication token for Milvus.",
    )
    parser.add_argument(
        "--collection_name",
        type=str,
        default="financebench",
        help="Name of the Milvus collection to create/use.",
    )

    # Dense embedding model parameters.
    # The string defaults below are run through ``dict_type`` by argparse,
    # the same way values given on the command line are.
    parser.add_argument(
        "--dense_embedding_model",
        type=str,
        default="sentence-transformers/all-mpnet-base-v2",
        help="Model name for dense embeddings.",
    )
    parser.add_argument(
        "--dense_model_kwargs",
        type=dict_type,
        default='{"device": "cpu", "trust_remote_code": true}',
        help="Keyword arguments for dense embeddings model initialization.",
    )
    parser.add_argument(
        "--dense_encode_kwargs",
        type=dict_type,
        default='{"normalize_embeddings": true}',
        help="Keyword arguments for dense embeddings encoding.",
    )

    # Sparse embedding model parameters
    parser.add_argument(
        "--sparse_embedding_model",
        type=str,
        default="Splade_PP_en_v1",
        help="Model name for sparse embeddings.",
    )

    # Schema configuration parameters
    # Field names
    parser.add_argument(
        "--pk_field",
        type=str,
        default="doc_id",
        help="Name of the primary key field.",
    )
    parser.add_argument(
        "--dense_field",
        type=str,
        default="dense_vector",
        help="Name of the dense vector field.",
    )
    parser.add_argument(
        "--sparse_field",
        type=str,
        default="sparse_vector",
        help="Name of the sparse vector field.",
    )
    parser.add_argument(
        "--text_field",
        type=str,
        default="text",
        help="Name of the text field.",
    )
    parser.add_argument(
        "--metadata_field",
        type=str,
        default="metadata",
        help="Name of the metadata field.",
    )
    # Field sizes
    parser.add_argument(
        "--dense_dim",
        type=int,
        default=768,
        help="Dimension of dense embeddings.",
    )
    parser.add_argument(
        "--pk_max_length",
        type=int,
        default=100,
        help="Max length for the primary key field.",
    )
    parser.add_argument(
        "--text_max_length",
        type=int,
        default=65535,
        help="Max length for the text field.",
    )

    # Index parameters
    parser.add_argument(
        "--dense_index_params",
        type=dict_type,
        default='{"index_type": "FLAT", "metric_type": "IP"}',
        help="JSON string specifying dense index parameters.",
    )
    parser.add_argument(
        "--sparse_index_params",
        type=dict_type,
        default='{"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP"}',
        help="JSON string specifying sparse index parameters.",
    )

    # Collection creation flag
    parser.add_argument(
        "--create_new_collection",
        action="store_true",
        help="Create a new collection or use existing. Defaults to False.",
    )

    return parser.parse_args(argv)


def _build_collection_schema(args: argparse.Namespace) -> CollectionSchema:
    """Build the Milvus collection schema from the parsed field settings.

    Args:
        args: Parsed arguments providing the field names (``pk_field``,
            ``dense_field``, ``sparse_field``, ``text_field``,
            ``metadata_field``) and sizes (``pk_max_length``, ``dense_dim``,
            ``text_max_length``).

    Returns:
        CollectionSchema: Schema with an auto-generated VARCHAR primary key,
        a dense float vector, a sparse float vector, a VARCHAR text field and
        a JSON metadata field. Dynamic fields are disabled.
    """
    fields = [
        FieldSchema(
            name=args.pk_field,
            dtype=DataType.VARCHAR,
            is_primary=True,
            auto_id=True,
            max_length=args.pk_max_length,
        ),
        FieldSchema(name=args.dense_field, dtype=DataType.FLOAT_VECTOR, dim=args.dense_dim),
        FieldSchema(name=args.sparse_field, dtype=DataType.SPARSE_FLOAT_VECTOR),
        FieldSchema(name=args.text_field, dtype=DataType.VARCHAR, max_length=args.text_max_length),
        FieldSchema(name=args.metadata_field, dtype=DataType.JSON),
    ]
    return CollectionSchema(fields=fields, enable_dynamic_field=False)


def main() -> None:
    """Run the FinanceBench document processing pipeline.

    This function performs the following steps:
      1. Parses the command-line arguments.
      2. Constructs the dataloader, document loader, chunker and both
         embedding models.
      3. Builds the collection schema and initializes the Milvus vector
         database wrapper.
      4. Retrieves the corpus PDFs through the dataloader.
      5. Loads documents from the PDF directory and chunks them.
      6. Adds the chunked documents to Milvus with dense and sparse embeddings.
    """
    args = parse_arguments()

    # Initialize the FinanceBench dataloader
    dataloader = FinanceBenchDataloader(
        dataset_name=args.dataset_name,
        split=args.split,
    )

    # Loader that turns the PDFs in the provided directory into documents
    unstructured_document_loader = UnstructuredDocumentLoader(
        strategy=args.strategy,
        mode=args.mode,
    )

    # Chunker applied to the loaded documents
    chunker = UnstructuredChunker()

    # Initialize dense and sparse embedding models with additional parameters
    dense_embeddings = HuggingFaceEmbeddings(
        model_name=args.dense_embedding_model,
        model_kwargs=args.dense_model_kwargs,
        encode_kwargs=args.dense_encode_kwargs,
    )
    sparse_embeddings = SparseEmbeddings(
        model_name=args.sparse_embedding_model,
    )

    # Define Milvus collection fields and schema
    schema = _build_collection_schema(args)

    # Initialize the Milvus vector database client. This happens before any
    # PDF is fetched or parsed, matching the original statement order.
    milvus_vector_db = MilvusVectorDB(
        uri=args.milvus_uri,
        token=args.milvus_token,
        collection_name=args.collection_name,
        collection_schema=schema,
        dense_field=args.dense_field,
        sparse_field=args.sparse_field,
        text_field=args.text_field,
        metadata_field=args.metadata_field,
        dense_index_params=args.dense_index_params,
        sparse_index_params=args.sparse_index_params,
        create_new_collection=args.create_new_collection,
    )

    # Add documents to the Milvus vector database.
    # NOTE(review): the return value of get_corpus_pdfs() is discarded;
    # presumably it places the PDFs where --pdf_dir points — confirm.
    dataloader.get_corpus_pdfs()
    documents = unstructured_document_loader.transform_documents(args.pdf_dir)
    chunked_documents = chunker.transform_documents(documents)
    milvus_vector_db.add_documents(
        documents=chunked_documents,
        dense_embedding_model=dense_embeddings,
        sparse_embedding_model=sparse_embeddings,
    )


if __name__ == "__main__":
    main()