File size: 5,278 Bytes
f896763
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import json
from copy import deepcopy

from dotenv import find_dotenv, load_dotenv
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import Document
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.milvus import MilvusVectorStore
from llama_index.vector_stores.milvus.utils import BGEM3SparseEmbeddingFunction

from src.agent_hackathon.consts import PROJECT_ROOT_DIR
from src.agent_hackathon.logger import get_logger

# Module-level logger named "create_vector_db"; log files go to <project root>/logs.
logger = get_logger(log_name="create_vector_db", log_dir=PROJECT_ROOT_DIR / "logs")


class VectorDBCreator:
    """Handles creation of a Milvus vector database from arXiv data.

    Pipeline: load a JSON list of arXiv records, turn each record's abstract
    into a LlamaIndex ``Document`` (remaining keys become metadata), then embed
    and index the documents into a Milvus store with both dense and sparse
    vectors.
    """

    def __init__(
        self,
        data_path: str,
        db_uri: str,
        embedding_model: str = "Qwen/Qwen3-Embedding-0.6B",
        chunk_size: int = 20_000,
        chunk_overlap: int = 0,
        vector_dim: int = 1024,
        insert_batch_size: int = 8192,
    ) -> None:
        """
        Initialize the VectorDBCreator.

        Args:
            data_path: Path to the JSON data file.
            db_uri: URI for the Milvus database.
            embedding_model: Name of the embedding model.
            chunk_size: Size of text chunks for splitting.
            chunk_overlap: Overlap between text chunks.
            vector_dim: Dimension of the embedding vectors.
            insert_batch_size: Batch size for insertion.
        """
        self.data_path = data_path
        self.db_uri = db_uri
        self.embedding_model = embedding_model
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.vector_dim = vector_dim
        self.insert_batch_size = insert_batch_size
        # Embeddings are computed on CPU; change `device` here if a GPU is available.
        self.embed_model = HuggingFaceEmbedding(
            model_name=self.embedding_model, device="cpu"
        )
        self.sent_splitter = SentenceSplitter(
            chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
        )
        logger.info("VectorDBCreator initialized.")

    def load_data(self) -> list[dict]:
        """
        Load and return data from the JSON file.

        Returns:
            List of dictionaries containing arXiv data.

        Raises:
            FileNotFoundError: If ``data_path`` does not exist.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        logger.info(f"Loading data from {self.data_path}")
        # Explicit encoding avoids platform-dependent defaults. json.load builds
        # entirely fresh objects, so the previous defensive deepcopy was redundant.
        with open(file=self.data_path, encoding="utf-8") as f:
            data = json.load(fp=f)
        logger.info("Data loaded successfully.")
        return data

    def prepare_documents(self, data: list[dict]) -> list[Document]:
        """
        Convert raw data into a list of Document objects.

        Each record's "abstract" becomes the document text; all other keys are
        carried over as metadata. Unlike a ``pop``-based approach, the input
        dicts are NOT mutated.

        Args:
            data: List of dictionaries with arXiv data; each record must
                contain an "abstract" key.

        Returns:
            List of Document objects.

        Raises:
            KeyError: If a record lacks an "abstract" key.
        """
        logger.info("Preparing documents from data.")
        docs = [
            Document(
                text=d["abstract"],
                metadata={k: v for k, v in d.items() if k != "abstract"},
            )
            for d in data
        ]
        logger.info(f"Prepared {len(docs)} documents.")
        return docs

    def create_vector_store(self) -> MilvusVectorStore:
        """
        Create and return a MilvusVectorStore instance.

        Returns:
            Configured MilvusVectorStore with hybrid (dense + sparse) search
            enabled via a BGE-M3 sparse embedding function.
        """
        logger.info(f"Creating MilvusVectorStore at {self.db_uri}")
        store = MilvusVectorStore(
            uri=self.db_uri,
            dim=self.vector_dim,
            enable_sparse=True,
            sparse_embedding_function=BGEM3SparseEmbeddingFunction(),
        )
        logger.info("MilvusVectorStore created.")
        return store

    def build_index(
        self, docs_list: list[Document], vector_store: MilvusVectorStore
    ) -> VectorStoreIndex:
        """
        Build and return a VectorStoreIndex from documents.

        Args:
            docs_list: List of Document objects.
            vector_store: MilvusVectorStore instance.

        Returns:
            VectorStoreIndex object backed by the given Milvus store.
        """
        logger.info("Building VectorStoreIndex.")
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        index = VectorStoreIndex.from_documents(
            documents=docs_list,
            storage_context=storage_context,
            embed_model=self.embed_model,
            # Sentence splitting happens at index time via the transformation
            # pipeline rather than during Document construction.
            transformations=[self.sent_splitter],
            show_progress=True,
            insert_batch_size=self.insert_batch_size,
        )
        logger.info("VectorStoreIndex built.")
        return index

    def run(self) -> None:
        """
        Execute the full pipeline: load data, prepare documents, create
        vector store, and build index.
        """
        logger.info("Running full vector DB creation pipeline.")
        data = self.load_data()
        docs_list = self.prepare_documents(data=data)
        vector_store = self.create_vector_store()
        self.build_index(docs_list=docs_list, vector_store=vector_store)
        logger.info("Pipeline finished.")


# if __name__ == "__main__":
#     logger.info("Script started.")
#     # Optionally load environment variables if needed
#     _ = load_dotenv(dotenv_path=find_dotenv(raise_error_if_not_found=True))
#     creator = VectorDBCreator(
#         data_path=f"{PROJECT_ROOT_DIR}/data/cs_data_arxiv.json", db_uri="arxiv_docs.db"
#     )
#     creator.run()
#     logger.info("Script finished.")