# ml_research_agent/src/agent_hackathon/create_vector_db.py
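"""Build a Milvus vector database over arXiv abstracts using LlamaIndex."""
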
import json
from copy import deepcopy
from dotenv import find_dotenv, load_dotenv
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import Document
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.milvus import MilvusVectorStore
from llama_index.vector_stores.milvus.utils import BGEM3SparseEmbeddingFunction
from src.agent_hackathon.consts import PROJECT_ROOT_DIR
from src.agent_hackathon.logger import get_logger
logger = get_logger(log_name="create_vector_db", log_dir=PROJECT_ROOT_DIR / "logs")
class VectorDBCreator:
"""Handles creation of a Milvus vector database from arXiv data."""
def __init__(
self,
data_path: str,
db_uri: str,
embedding_model: str = "Qwen/Qwen3-Embedding-0.6B",
chunk_size: int = 20_000,
chunk_overlap: int = 0,
vector_dim: int = 1024,
insert_batch_size: int = 8192,
) -> None:
"""
Initialize the VectorDBCreator.
Args:
data_path: Path to the JSON data file.
db_uri: URI for the Milvus database.
embedding_model: Name of the embedding model.
chunk_size: Size of text chunks for splitting.
chunk_overlap: Overlap between text chunks.
vector_dim: Dimension of the embedding vectors.
insert_batch_size: Batch size for insertion.
"""
self.data_path = data_path
self.db_uri = db_uri
self.embedding_model = embedding_model
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.vector_dim = vector_dim
self.insert_batch_size = insert_batch_size
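        # Load the embedding model locally on the CPU; Qwen3-Embedding-0.6B
        # produces 1024-dim vectors, matching the default vector_dim above.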
self.embed_model = HuggingFaceEmbedding(
model_name=self.embedding_model, device="cpu"
)
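        # A 20k-token chunk size with no overlap means a typical arXiv
        # abstract is indexed as a single chunk.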
self.sent_splitter = SentenceSplitter(
chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
)
logger.info("VectorDBCreator initialized.")
def load_data(self) -> list[dict]:
"""
Load and return data from the JSON file.
Returns:
List of dictionaries containing arXiv data.
"""
logger.info(f"Loading data from {self.data_path}")
        with open(self.data_path, encoding="utf-8") as f:
            data = json.load(f)
        logger.info("Data loaded successfully.")
        # Return a deep copy so downstream steps can mutate records freely.
        return deepcopy(data)
def prepare_documents(self, data: list[dict]) -> list[Document]:
"""
Convert raw data into a list of Document objects.
Args:
data: List of dictionaries with arXiv data.
Returns:
List of Document objects.
"""
logger.info("Preparing documents from data.")
docs = [Document(text=d.pop("abstract"), metadata=d) for d in data]
logger.info(f"Prepared {len(docs)} documents.")
return docs
def create_vector_store(self) -> MilvusVectorStore:
"""
Create and return a MilvusVectorStore instance.
Returns:
Configured MilvusVectorStore.
"""
logger.info(f"Creating MilvusVectorStore at {self.db_uri}")
store = MilvusVectorStore(
uri=self.db_uri,
dim=self.vector_dim,
enable_sparse=True,
sparse_embedding_function=BGEM3SparseEmbeddingFunction(),
)
logger.info("MilvusVectorStore created.")
return store
def build_index(
self, docs_list: list[Document], vector_store: MilvusVectorStore
) -> VectorStoreIndex:
"""
Build and return a VectorStoreIndex from documents.
Args:
docs_list: List of Document objects.
vector_store: MilvusVectorStore instance.
Returns:
VectorStoreIndex object.
"""
logger.info("Building VectorStoreIndex.")
storage_context = StorageContext.from_defaults(vector_store=vector_store)
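        # from_documents applies the sentence splitter, embeds each chunk with
        # the HuggingFace model, and inserts nodes into Milvus in batches of
        # insert_batch_size.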
index = VectorStoreIndex.from_documents(
documents=docs_list,
storage_context=storage_context,
embed_model=self.embed_model,
transformations=[self.sent_splitter],
show_progress=True,
insert_batch_size=self.insert_batch_size,
)
logger.info("VectorStoreIndex built.")
return index
def run(self) -> None:
"""
Execute the full pipeline: load data, prepare documents, create vector store, and build index.
"""
logger.info("Running full vector DB creation pipeline.")
data = self.load_data()
docs_list = self.prepare_documents(data=data)
vector_store = self.create_vector_store()
self.build_index(docs_list=docs_list, vector_store=vector_store)
logger.info("Pipeline finished.")
if __name__ == "__main__":
    logger.info("Script started.")
    # Optionally load environment variables if needed.
    _ = load_dotenv(dotenv_path=find_dotenv(raise_error_if_not_found=True))
    creator = VectorDBCreator(
        data_path=f"{PROJECT_ROOT_DIR}/data/cs_data_arxiv.json",
        db_uri="arxiv_docs.db",
    )
    creator.run()
    logger.info("Script finished.")