# Source: test-ragp/scripts/indexing_financebench_milvus.py
# Uploaded by awinml in commit 6c5ce7a ("Upload 11 files", verified).
import argparse
from collections.abc import Sequence
from typing import Optional

from langchain_huggingface import HuggingFaceEmbeddings
from pymilvus import CollectionSchema, DataType, FieldSchema

from dataloaders.langchain import FinanceBenchDataloader
from rag_pipelines.embeddings import SparseEmbeddingsMilvus as SparseEmbeddings
from rag_pipelines.unstructured import UnstructuredChunker, UnstructuredDocumentLoader
from rag_pipelines.utils import dict_type
from rag_pipelines.vectordb import MilvusVectorDB
def parse_arguments(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse command-line arguments for the FinanceBench indexing pipeline.

    Args:
        argv: Argument strings to parse, excluding the program name. When
            ``None`` (the default), argparse falls back to ``sys.argv[1:]``,
            which is exactly the previous behaviour. Passing an explicit list
            lets other scripts or tests drive the parser without touching
            ``sys.argv``.

    Returns:
        argparse.Namespace: Parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(
        description="Run the FinanceBench pipeline to load, process, chunk, embed, and index documents."
    )

    # FinanceBench dataset parameters
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="PatronusAI/financebench",
        help="HuggingFace dataset name.",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="train",
        help="Dataset split to use (e.g., 'train').",
    )

    # PDF directory for unstructured document loader
    parser.add_argument(
        "--pdf_dir",
        type=str,
        default="pdfs/",
        help="Directory path containing PDF files.",
    )

    # UnstructuredDocumentLoader parameters
    parser.add_argument(
        "--strategy",
        type=str,
        default="fast",
        help="Processing strategy for the unstructured document loader.",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="elements",
        help="Extraction mode for the unstructured document loader.",
    )

    # Milvus connection parameters.
    # Neither option has a default: when omitted they are None, and main()
    # passes that None straight through to MilvusVectorDB.
    parser.add_argument(
        "--milvus_uri",
        type=str,
        help="URI for the Milvus server.",
    )
    parser.add_argument(
        "--milvus_token",
        type=str,
        help="Authentication token for Milvus.",
    )
    parser.add_argument(
        "--collection_name",
        type=str,
        default="financebench",
        help="Name of the Milvus collection to create/use.",
    )

    # Dense embedding model parameters.
    # argparse runs *string* defaults through `type`, so the JSON-looking
    # defaults below are converted by dict_type just like user-supplied values
    # (presumably into a dict via JSON decoding; confirm in rag_pipelines.utils).
    parser.add_argument(
        "--dense_embedding_model",
        type=str,
        default="sentence-transformers/all-mpnet-base-v2",
        help="Model name for dense embeddings.",
    )
    parser.add_argument(
        "--dense_model_kwargs",
        type=dict_type,
        default='{"device": "cpu", "trust_remote_code": true}',
        help="Keyword arguments for dense embeddings model initialization.",
    )
    parser.add_argument(
        "--dense_encode_kwargs",
        type=dict_type,
        default='{"normalize_embeddings": true}',
        help="Keyword arguments for dense embeddings encoding.",
    )

    # Sparse embedding model parameters
    parser.add_argument(
        "--sparse_embedding_model",
        type=str,
        default="Splade_PP_en_v1",
        help="Model name for sparse embeddings.",
    )

    # Schema configuration parameters: field names
    parser.add_argument(
        "--pk_field",
        type=str,
        default="doc_id",
        help="Name of the primary key field.",
    )
    parser.add_argument(
        "--dense_field",
        type=str,
        default="dense_vector",
        help="Name of the dense vector field.",
    )
    parser.add_argument(
        "--sparse_field",
        type=str,
        default="sparse_vector",
        help="Name of the sparse vector field.",
    )
    parser.add_argument(
        "--text_field",
        type=str,
        default="text",
        help="Name of the text field.",
    )
    parser.add_argument(
        "--metadata_field",
        type=str,
        default="metadata",
        help="Name of the metadata field.",
    )

    # Schema sizes.
    # --dense_dim becomes the FLOAT_VECTOR dimension, so it must equal the
    # output size of --dense_embedding_model (768 presumably matches the
    # default model; re-check whenever the model is changed).
    parser.add_argument(
        "--dense_dim",
        type=int,
        default=768,
        help="Dimension of dense embeddings.",
    )
    parser.add_argument(
        "--pk_max_length",
        type=int,
        default=100,
        help="Max length for the primary key field.",
    )
    parser.add_argument(
        "--text_max_length",
        type=int,
        default=65535,
        help="Max length for the text field.",
    )

    # Index parameters (string defaults are also converted by dict_type).
    parser.add_argument(
        "--dense_index_params",
        type=dict_type,
        default='{"index_type": "FLAT", "metric_type": "IP"}',
        help="JSON string specifying dense index parameters.",
    )
    parser.add_argument(
        "--sparse_index_params",
        type=dict_type,
        default='{"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP"}',
        help="JSON string specifying sparse index parameters.",
    )

    # Collection creation flag (store_true => False unless given).
    parser.add_argument(
        "--create_new_collection",
        action="store_true",
        help="Create a new collection or use existing. Defaults to False.",
    )

    return parser.parse_args(argv)
def _build_collection_schema(args: argparse.Namespace) -> CollectionSchema:
    """Build the Milvus collection schema described by the parsed CLI options.

    Fields, in order: a VARCHAR primary key with ``auto_id=True``, a dense
    FLOAT_VECTOR of ``args.dense_dim`` dimensions, a SPARSE_FLOAT_VECTOR, a
    VARCHAR text field and a JSON metadata field. Dynamic fields are disabled,
    so only these fields are part of the schema.
    """
    fields = [
        FieldSchema(
            name=args.pk_field,
            dtype=DataType.VARCHAR,
            is_primary=True,
            # auto_id=True: primary-key values are generated by Milvus, so the
            # documents being inserted do not need to carry one.
            auto_id=True,
            max_length=args.pk_max_length,
        ),
        FieldSchema(name=args.dense_field, dtype=DataType.FLOAT_VECTOR, dim=args.dense_dim),
        FieldSchema(name=args.sparse_field, dtype=DataType.SPARSE_FLOAT_VECTOR),
        FieldSchema(name=args.text_field, dtype=DataType.VARCHAR, max_length=args.text_max_length),
        FieldSchema(name=args.metadata_field, dtype=DataType.JSON),
    ]
    return CollectionSchema(fields=fields, enable_dynamic_field=False)


def main() -> None:
    """Run the FinanceBench document indexing pipeline.

    Steps:
        1. Parse command-line arguments.
        2. Fetch the FinanceBench corpus PDFs via ``FinanceBenchDataloader``.
        3. Load the PDFs in ``--pdf_dir`` with ``UnstructuredDocumentLoader``.
        4. Chunk the documents with ``UnstructuredChunker``.
        5. Initialise the dense (HuggingFace) and sparse embedding models.
        6. Set up ``MilvusVectorDB`` with the configured schema and index the chunks.

    All document fetching, parsing and chunking happens before
    ``MilvusVectorDB`` is constructed. A failure there, or an empty corpus, is
    therefore reported before any collection is created or reused (presumably
    that happens in the MilvusVectorDB constructor; confirm in rag_pipelines).

    Raises:
        SystemExit: If chunking produced no documents to index.
    """
    args = parse_arguments()

    # --- Corpus acquisition and preprocessing (no Milvus interaction yet) ---
    dataloader = FinanceBenchDataloader(
        dataset_name=args.dataset_name,
        split=args.split,
    )
    # NOTE(review): get_corpus_pdfs() is not given --pdf_dir; confirm it writes
    # the PDFs into the same directory the loader reads from below.
    dataloader.get_corpus_pdfs()

    unstructured_document_loader = UnstructuredDocumentLoader(
        strategy=args.strategy,
        mode=args.mode,
    )
    documents = unstructured_document_loader.transform_documents(args.pdf_dir)

    chunker = UnstructuredChunker()
    chunked_documents = chunker.transform_documents(documents)
    # Fail fast rather than connecting to Milvus just to index nothing.
    if not chunked_documents:
        raise SystemExit(
            f"No document chunks were produced from {args.pdf_dir!r}; nothing to index."
        )

    # --- Embedding models ---
    dense_embeddings = HuggingFaceEmbeddings(
        model_name=args.dense_embedding_model,
        model_kwargs=args.dense_model_kwargs,
        encode_kwargs=args.dense_encode_kwargs,
    )
    sparse_embeddings = SparseEmbeddings(
        model_name=args.sparse_embedding_model,
    )

    # --- Vector store setup and ingestion ---
    milvus_vector_db = MilvusVectorDB(
        uri=args.milvus_uri,
        token=args.milvus_token,
        collection_name=args.collection_name,
        collection_schema=_build_collection_schema(args),
        dense_field=args.dense_field,
        sparse_field=args.sparse_field,
        text_field=args.text_field,
        metadata_field=args.metadata_field,
        dense_index_params=args.dense_index_params,
        sparse_index_params=args.sparse_index_params,
        create_new_collection=args.create_new_collection,
    )
    milvus_vector_db.add_documents(
        documents=chunked_documents,
        dense_embedding_model=dense_embeddings,
        sparse_embedding_model=sparse_embeddings,
    )


if __name__ == "__main__":
    main()