PercivalFletcher committed on
Commit
bab3eba
·
verified ·
1 Parent(s): 0266bc6

Create pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +181 -0
pipeline.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import time
3
+ from pathlib import Path
4
+ from typing import List, Any
5
+ import asyncio # Import asyncio for concurrent operations
6
+
7
+ from llama_index.core import Document, StorageContext, VectorStoreIndex, Settings
8
+ from llama_index.core.node_parser import HierarchicalNodeParser, get_leaf_nodes, get_root_nodes
9
+ from llama_index.core.retrievers import AutoMergingRetriever, BaseRetriever
10
+ from llama_index.core.storage.docstore import SimpleDocumentStore
11
+ from llama_index.readers.file import PyMuPDFReader
12
+ from llama_index.llms.groq import Groq
13
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
14
+
15
+
16
class Pipeline:
    """
    A pipeline to process a PDF, create hierarchical nodes, and generate
    embeddings for the leaf nodes.

    It exposes a retriever to fetch nodes for a given query, but does not
    handle the answer generation itself. The embedding model is passed in,
    not initialized internally.
    """

    def __init__(
        self,
        groq_api_key: str,
        pdf_path: str,
        embed_model: HuggingFaceEmbedding,
        batch_size: int = 300,
        similarity_top_k: int = 6,
    ):
        """
        Initializes the pipeline with API keys, file path, and a pre-initialized embedding model.

        Args:
            groq_api_key (str): Your API key for Groq.
            pdf_path (str): The path to the PDF file to be processed.
            embed_model (HuggingFaceEmbedding): The pre-initialized embedding model.
            batch_size (int): Number of node texts sent to the embedding model
                per concurrent request. Defaults to 300 (the previously
                hard-coded value). Tune to your system's memory and CPU/GPU.
            similarity_top_k (int): Number of leaf nodes fetched by the base
                retriever before auto-merging. Defaults to 6 (the previously
                hard-coded value).

        Raises:
            ValueError: If batch_size is not a positive integer.
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer")

        self.groq_api_key = groq_api_key
        self.pdf_path = Path(pdf_path)
        self.embed_model = embed_model
        self.batch_size = batch_size
        self.similarity_top_k = similarity_top_k

        # Configure the Llama-Index LLM setting only; embeddings are handled
        # explicitly by this instance rather than through global Settings.
        # NOTE(review): this mutates the process-global Settings object, so two
        # Pipeline instances with different keys will overwrite each other's
        # LLM configuration — confirm this is acceptable for callers.
        Settings.llm = Groq(model="llama3-70b-8192", api_key=self.groq_api_key)

        # Pipeline state, populated by run().
        self.documents: List[Document] = []
        self.nodes: List[Any] = []
        self.storage_context: StorageContext | None = None
        self.index: VectorStoreIndex | None = None
        self.retriever: BaseRetriever | None = None
        self.leaf_nodes: List[Any] = []
        self.root_nodes: List[Any] = []

    def _parse_pdf(self) -> None:
        """Parses the PDF file into Llama-Index Document objects.

        All pages are concatenated into a single Document for simpler
        downstream processing; adjust this if per-page context is needed.
        """
        print(f"Parsing PDF at: {self.pdf_path}")
        start_time = time.perf_counter()
        loader = PyMuPDFReader()
        docs = loader.load(file_path=self.pdf_path)
        # Concatenate all document parts into a single document for simpler processing
        doc_text = "\n\n".join([d.get_content() for d in docs])
        self.documents = [Document(text=doc_text)]
        end_time = time.perf_counter()
        print(f"PDF parsing completed in {end_time - start_time:.2f} seconds.")

    def _create_nodes(self) -> None:
        """Creates hierarchical nodes from the parsed documents.

        Populates self.nodes (all levels), self.leaf_nodes (embedded and
        indexed), and self.root_nodes.
        """
        print("Creating nodes from documents...")
        start_time = time.perf_counter()
        node_parser = HierarchicalNodeParser.from_defaults()
        self.nodes = node_parser.get_nodes_from_documents(self.documents)
        self.leaf_nodes = get_leaf_nodes(self.nodes)
        self.root_nodes = get_root_nodes(self.nodes)
        end_time = time.perf_counter()
        print(f"Node creation completed in {end_time - start_time:.2f} seconds.")

    async def _generate_embeddings_concurrently(self) -> None:
        """
        Generates embeddings for leaf nodes concurrently using asyncio.to_thread
        and then builds the VectorStoreIndex.

        Batches of leaf-node texts are embedded in worker threads so the event
        loop is not blocked; the resulting vectors are attached to their nodes
        before index construction, so the index does not re-embed them.

        NOTE(review): concurrent threads all call the single shared
        embed_model instance — confirm the model is thread-safe for inference.
        """
        print("Generating embeddings for leaf nodes concurrently...")
        start_time_embeddings = time.perf_counter()

        # Extract text content from leaf nodes.
        node_texts = [node.get_content() for node in self.leaf_nodes]

        # Create batches of texts and schedule embedding generation in
        # separate threads. asyncio.to_thread runs the synchronous embedding
        # call off the event loop so batches can overlap.
        embedding_tasks = [
            asyncio.to_thread(
                self.embed_model.get_text_embedding_batch,
                texts=node_texts[i : i + self.batch_size],
                show_progress=False,
            )
            for i in range(0, len(node_texts), self.batch_size)
        ]

        # Wait for all concurrent embedding tasks to complete.
        all_embeddings_batches = await asyncio.gather(*embedding_tasks)

        # Flatten the list of per-batch embedding lists into a single list;
        # batching preserves order, so positions line up with leaf_nodes.
        flat_embeddings = [emb for sublist in all_embeddings_batches for emb in sublist]

        # Assign the generated embeddings back to their respective leaf nodes.
        for node, embedding in zip(self.leaf_nodes, flat_embeddings):
            node.embedding = embedding

        end_time_embeddings = time.perf_counter()
        print(f"Embeddings generated for {len(self.leaf_nodes)} nodes in {end_time_embeddings - start_time_embeddings:.2f} seconds.")

        # Now build the VectorStoreIndex using nodes with pre-computed embeddings.
        print("Building VectorStoreIndex...")
        start_time_index_build = time.perf_counter()

        # Add all nodes (root and leaf) to the document store so the
        # auto-merging retriever can walk parents.
        docstore = SimpleDocumentStore()
        docstore.add_documents(self.nodes)

        self.storage_context = StorageContext.from_defaults(docstore=docstore)

        # When nodes already carry embeddings, VectorStoreIndex uses them
        # instead of re-embedding.
        self.index = VectorStoreIndex(
            self.leaf_nodes,  # Leaf nodes now contain their embeddings
            storage_context=self.storage_context,
            embed_model=self.embed_model,  # Still needed for embedding queries at retrieval time
        )
        end_time_index_build = time.perf_counter()
        print(f"VectorStoreIndex built in {end_time_index_build - start_time_index_build:.2f} seconds.")
        print(f"Total index generation and embedding process completed in {end_time_index_build - start_time_embeddings:.2f} seconds.")

    def _setup_retriever(self) -> None:
        """Sets up the auto-merging retriever on top of the vector index.

        Raises:
            RuntimeError: If the index has not been built yet.
        """
        # Guard: run() must have built the index/storage context first.
        if self.index is None or self.storage_context is None:
            raise RuntimeError("Index is not built. Run the pipeline first.")
        print("Setting up retriever...")
        base_retriever = self.index.as_retriever(similarity_top_k=self.similarity_top_k)
        self.retriever = AutoMergingRetriever(
            base_retriever, storage_context=self.storage_context, verbose=True
        )

    async def run(self) -> None:
        """Runs the entire pipeline from parsing to retriever setup.

        Raises:
            FileNotFoundError: If the configured PDF path does not exist.
        """
        if not self.pdf_path.exists():
            raise FileNotFoundError(f"PDF file not found at: {self.pdf_path}")

        self._parse_pdf()
        self._create_nodes()
        await self._generate_embeddings_concurrently()  # Await the async embedding generation
        self._setup_retriever()
        print("Pipeline is ready for retrieval.")

    def retrieve_nodes(self, query_str: str) -> List[dict]:
        """
        Retrieves relevant nodes for a given query and converts them to a
        list of dictionaries for external use.

        Args:
            query_str (str): The query string.

        Returns:
            List[dict]: A list of dictionaries with node content and metadata.

        Raises:
            RuntimeError: If the retriever has not been initialized via run().
        """
        if not self.retriever:
            raise RuntimeError("Retriever is not initialized. Run the pipeline first.")

        print(f"\nRetrieving nodes for query: '{query_str}'")
        start_time = time.perf_counter()

        # This is a synchronous call
        nodes = self.retriever.retrieve(query_str)

        end_time = time.perf_counter()
        print(f"Retrieval completed in {end_time - start_time:.2f} seconds. Found {len(nodes)} nodes.")

        # Convert the Llama-Index nodes to a dictionary format
        retrieved_results = [
            {
                "content": n.text,
                "document_metadata": n.metadata
            }
            for n in nodes
        ]
        return retrieved_results