File size: 5,278 Bytes
f896763 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import json
from copy import deepcopy
from dotenv import find_dotenv, load_dotenv
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import Document
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.milvus import MilvusVectorStore
from llama_index.vector_stores.milvus.utils import BGEM3SparseEmbeddingFunction
from src.agent_hackathon.consts import PROJECT_ROOT_DIR
from src.agent_hackathon.logger import get_logger
# Module-level logger for this script; log files land in <project_root>/logs.
logger = get_logger(log_name="create_vector_db", log_dir=PROJECT_ROOT_DIR / "logs")
class VectorDBCreator:
    """Handles creation of a Milvus vector database from arXiv data.

    Pipeline: load a JSON file of arXiv records, turn each record's
    "abstract" into a Document (remaining fields become metadata),
    create a Milvus store with dense + sparse (BGE-M3) embeddings,
    and build a VectorStoreIndex over the documents.
    """

    def __init__(
        self,
        data_path: str,
        db_uri: str,
        embedding_model: str = "Qwen/Qwen3-Embedding-0.6B",
        chunk_size: int = 20_000,
        chunk_overlap: int = 0,
        vector_dim: int = 1024,
        insert_batch_size: int = 8192,
    ) -> None:
        """
        Initialize the VectorDBCreator.

        Args:
            data_path: Path to the JSON data file.
            db_uri: URI for the Milvus database.
            embedding_model: Name of the HuggingFace embedding model.
            chunk_size: Size of text chunks for splitting.
            chunk_overlap: Overlap between text chunks.
            vector_dim: Dimension of the dense embedding vectors.
            insert_batch_size: Batch size for insertion into the index.
        """
        self.data_path = data_path
        self.db_uri = db_uri
        self.embedding_model = embedding_model
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.vector_dim = vector_dim
        self.insert_batch_size = insert_batch_size
        # NOTE(review): device is hard-pinned to CPU — presumably deliberate
        # for an offline, one-shot ingestion job; confirm if GPU is available.
        self.embed_model = HuggingFaceEmbedding(
            model_name=self.embedding_model, device="cpu"
        )
        self.sent_splitter = SentenceSplitter(
            chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
        )
        logger.info("VectorDBCreator initialized.")

    def load_data(self) -> list[dict]:
        """
        Load and return data from the JSON file at ``self.data_path``.

        Returns:
            List of dictionaries containing arXiv data.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        logger.info(f"Loading data from {self.data_path}")
        # Explicit encoding: without it, open() falls back to the platform
        # locale, which breaks on non-UTF-8 systems (e.g. Windows cp1252).
        with open(file=self.data_path, encoding="utf-8") as f:
            data = json.load(fp=f)
        logger.info("Data loaded successfully.")
        # json.load returns freshly built objects that nothing else aliases,
        # so the previous deepcopy of the result was redundant work.
        return data

    def prepare_documents(self, data: list[dict]) -> list[Document]:
        """
        Convert raw data into a list of Document objects.

        Each record's "abstract" becomes the document text; every other
        field is carried along as metadata. The input dictionaries are NOT
        mutated (the previous implementation popped "abstract" out of the
        caller's dicts as a hidden side effect).

        Args:
            data: List of dictionaries with arXiv data; each entry must
                contain an "abstract" key.

        Returns:
            List of Document objects.

        Raises:
            KeyError: If an entry has no "abstract" key.
        """
        logger.info("Preparing documents from data.")
        docs = [
            Document(
                text=d["abstract"],
                metadata={k: v for k, v in d.items() if k != "abstract"},
            )
            for d in data
        ]
        logger.info(f"Prepared {len(docs)} documents.")
        return docs

    def create_vector_store(self) -> MilvusVectorStore:
        """
        Create and return a MilvusVectorStore instance.

        The store is configured for hybrid retrieval: dense vectors of
        ``self.vector_dim`` plus sparse BGE-M3 embeddings.

        Returns:
            Configured MilvusVectorStore.
        """
        logger.info(f"Creating MilvusVectorStore at {self.db_uri}")
        store = MilvusVectorStore(
            uri=self.db_uri,
            dim=self.vector_dim,
            enable_sparse=True,
            sparse_embedding_function=BGEM3SparseEmbeddingFunction(),
        )
        logger.info("MilvusVectorStore created.")
        return store

    def build_index(
        self, docs_list: list[Document], vector_store: MilvusVectorStore
    ) -> VectorStoreIndex:
        """
        Build and return a VectorStoreIndex from documents.

        Args:
            docs_list: List of Document objects to index.
            vector_store: MilvusVectorStore instance backing the index.

        Returns:
            VectorStoreIndex object.
        """
        logger.info("Building VectorStoreIndex.")
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        index = VectorStoreIndex.from_documents(
            documents=docs_list,
            storage_context=storage_context,
            embed_model=self.embed_model,
            transformations=[self.sent_splitter],
            show_progress=True,
            insert_batch_size=self.insert_batch_size,
        )
        logger.info("VectorStoreIndex built.")
        return index

    def run(self) -> None:
        """
        Execute the full pipeline: load data, prepare documents, create the
        vector store, and build the index.
        """
        logger.info("Running full vector DB creation pipeline.")
        data = self.load_data()
        docs_list = self.prepare_documents(data=data)
        vector_store = self.create_vector_store()
        self.build_index(docs_list=docs_list, vector_store=vector_store)
        logger.info("Pipeline finished.")
# if __name__ == "__main__":
# logger.info("Script started.")
# # Optionally load environment variables if needed
# _ = load_dotenv(dotenv_path=find_dotenv(raise_error_if_not_found=True))
# creator = VectorDBCreator(
# data_path=f"{PROJECT_ROOT_DIR}/data/cs_data_arxiv.json", db_uri="arxiv_docs.db"
# )
# creator.run()
# logger.info("Script finished.")
|