import asyncio  # Used to run blocking embedding calls off the event loop
import time
from pathlib import Path
from typing import Any, List, Union

import faiss

from llama_index.core import Document, StorageContext, VectorStoreIndex, Settings
from llama_index.core.node_parser import HierarchicalNodeParser, get_leaf_nodes, get_root_nodes
from llama_index.core.retrievers import AutoMergingRetriever, BaseRetriever
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.readers.file import PyMuPDFReader
from llama_index.llms.groq import Groq
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore


class Pipeline:
    """
    A pipeline to process a PDF, create nodes, and generate embeddings.
    It exposes a retriever to fetch nodes for a given query,
    but does not handle the answer generation itself. The embedding
    model is now passed in, not initialized internally.
    """

    def __init__(self, groq_api_key: str, pdf_path: str, embed_model: HuggingFaceEmbedding):
        """
        Initializes the pipeline with API keys, file path, and a pre-initialized embedding model.

        Args:
            groq_api_key (str): Your API key for Groq.
            pdf_path (str): The path to the PDF file to be processed.
            embed_model (HuggingFaceEmbedding): The pre-initialized embedding model.
        """
        self.groq_api_key = groq_api_key
        self.pdf_path = Path(pdf_path)
        self.embed_model = embed_model
        # Embedding dimension: 384 matches 'all-MiniLM-L6-v2'. Adjust if the
        # injected model produces vectors of a different size.
        self.d = 384
        # Configure the global Llama-Index LLM. This pipeline never calls it
        # directly, but downstream consumers of the retriever may rely on it.
        Settings.llm = Groq(model="llama3-70b-8192", api_key=self.groq_api_key)
        
        # Initialize components
        self.documents: List[Document] = []
        self.nodes: List[Any] = []
        self.storage_context: Union[StorageContext, None] = None
        self.index: Union[VectorStoreIndex, None] = None
        self.retriever: Union[BaseRetriever, None] = None
        self.leaf_nodes: List[Any] = []
        self.root_nodes: List[Any] = []


    def _parse_pdf(self) -> None:
        """Parses the PDF file into Llama-Index Document objects."""
        print(f"Parsing PDF at: {self.pdf_path}")
        start_time = time.perf_counter()
        loader = PyMuPDFReader()
        docs = loader.load(file_path=self.pdf_path)
        # Concatenate all document parts into a single document for simpler processing
        # Adjust this if you need to maintain per-page document context
        doc_text = "\n\n".join([d.get_content() for d in docs])
        self.documents = [Document(text=doc_text)]
        end_time = time.perf_counter()
        print(f"PDF parsing completed in {end_time - start_time:.2f} seconds.")

    def _create_nodes(self) -> None:
        """Creates hierarchical nodes from the parsed documents."""
        print("Creating nodes from documents...")
        start_time = time.perf_counter()
        node_parser = HierarchicalNodeParser.from_defaults()
        self.nodes = node_parser.get_nodes_from_documents(self.documents)
        self.leaf_nodes = get_leaf_nodes(self.nodes)
        self.root_nodes = get_root_nodes(self.nodes)
        end_time = time.perf_counter()
        print(f"Node creation completed in {end_time - start_time:.2f} seconds.")

    async def _generate_embeddings_concurrently(self) -> None:
        """
        Generates embeddings for leaf nodes concurrently using asyncio.to_thread
        and then builds the VectorStoreIndex.
        """
        print("Generating embeddings for leaf nodes concurrently...")
        start_time_embeddings = time.perf_counter()

        # Define a batch size for sending texts to the embedding model
        # Adjust this based on your system's memory and CPU/GPU capabilities
        BATCH_SIZE = 300 

        embedding_tasks = []
        # Extract text content from leaf nodes
        node_texts = [node.get_content() for node in self.leaf_nodes]

        # Create batches of texts and schedule embedding generation in separate threads
        for i in range(0, len(node_texts), BATCH_SIZE):
            batch_texts = node_texts[i : i + BATCH_SIZE]
            # Run the synchronous embedding call in a worker thread via
            # asyncio.to_thread so it does not block the event loop.
            embedding_tasks.append(
                asyncio.to_thread(
                    self.embed_model.get_text_embedding_batch,
                    texts=batch_texts,
                    show_progress=False,
                )
            )

        # Wait for all concurrent embedding tasks to complete
        all_embeddings_batches = await asyncio.gather(*embedding_tasks)

        # Flatten the list of lists of embeddings into a single list
        flat_embeddings = [emb for sublist in all_embeddings_batches for emb in sublist]

        # Assign the generated embeddings back to their respective leaf nodes
        for i, node in enumerate(self.leaf_nodes):
            node.embedding = flat_embeddings[i]

        end_time_embeddings = time.perf_counter()
        print(f"Embeddings generated for {len(self.leaf_nodes)} nodes in {end_time_embeddings - start_time_embeddings:.2f} seconds.")

        # --- FAISS Integration ---
        print("Building VectorStoreIndex with FAISS...")
        start_time_index_build = time.perf_counter()
        
        # 1. Create a FAISS index. IndexFlatL2 performs exact L2 search, which
        #    is fine for small-to-medium corpora; consider an approximate index
        #    (e.g. IndexIVFFlat) for very large ones.
        faiss_index = faiss.IndexFlatL2(self.d)
        
        # 2. Create the FaissVectorStore instance
        vector_store = FaissVectorStore(faiss_index=faiss_index)
        
        # 3. Create the StorageContext, passing in our custom vector store
        docstore = SimpleDocumentStore()
        docstore.add_documents(self.nodes)
        self.storage_context = StorageContext.from_defaults(
            docstore=docstore,
            vector_store=vector_store # Use the FAISS vector store
        )
        
        # 4. Build the index. The leaf nodes already carry embeddings, so
        #    LlamaIndex adds them to the FAISS store without re-embedding.
        self.index = VectorStoreIndex(
            self.leaf_nodes,
            storage_context=self.storage_context,
            embed_model=self.embed_model
        )
        end_time_index_build = time.perf_counter()
        print(f"VectorStoreIndex with FAISS built in {end_time_index_build - start_time_index_build:.2f} seconds.")
        print(f"Total index generation and embedding process completed in {end_time_index_build - start_time_embeddings:.2f} seconds.")
    

    def _setup_retriever(self) -> None:
        """Sets up the retriever."""
        print("Setting up retriever...")
        base_retriever = self.index.as_retriever(similarity_top_k=6)
        self.retriever = AutoMergingRetriever(
            base_retriever, storage_context=self.storage_context, verbose=True
        )

    async def run(self) -> None:
        """Runs the entire pipeline from parsing to retriever setup."""
        if not self.pdf_path.exists():
            raise FileNotFoundError(f"PDF file not found at: {self.pdf_path}")

        self._parse_pdf()
        self._create_nodes()
        await self._generate_embeddings_concurrently() # Await the async embedding generation
        self._setup_retriever()
        print("Pipeline is ready for retrieval.")

    def retrieve_nodes(self, query_str: str) -> List[dict]:
        """
        Retrieves relevant nodes for a given query and converts them to a
        list of dictionaries for external use.

        Args:
            query_str (str): The query string.

        Returns:
            List[dict]: A list of dictionaries with node content and metadata.
        """
        if not self.retriever:
            raise RuntimeError("Retriever is not initialized. Run the pipeline first.")

        print(f"\nRetrieving nodes for query: '{query_str}'")
        start_time = time.perf_counter()
        
        # This is a synchronous call
        nodes = self.retriever.retrieve(query_str)
        
        end_time = time.perf_counter()
        print(f"Retrieval completed in {end_time - start_time:.2f} seconds. Found {len(nodes)} nodes.")

        # Convert the Llama-Index nodes to a dictionary format
        retrieved_results = [
            {
                "content": n.text,
                "document_metadata": n.metadata
            }
            for n in nodes
        ]
        return retrieved_results
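

if __name__ == "__main__":
    # Minimal usage sketch, not part of the pipeline itself. The environment
    # variable, PDF path, and query below are hypothetical placeholders;
    # substitute your own values before running.
    import os

    embed_model = HuggingFaceEmbedding(
        model_name="sentence-transformers/all-MiniLM-L6-v2"  # 384-dim, matches self.d
    )
    pipeline = Pipeline(
        groq_api_key=os.environ["GROQ_API_KEY"],  # hypothetical env var
        pdf_path="example.pdf",                   # hypothetical path
        embed_model=embed_model,
    )
    asyncio.run(pipeline.run())

    results = pipeline.retrieve_nodes("What is this document about?")
    for r in results:
        print(r["content"][:200], "\n---")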