Update src/RAG_builder.py
Browse files- src/RAG_builder.py +11 -29
src/RAG_builder.py
CHANGED
|
@@ -29,46 +29,28 @@ class ProjectRAGGraph:
|
|
| 29 |
api_key=os.environ["OPENROUTER_API_KEY"]  # SECURITY: never hard-code secrets — the key previously committed on this line is public and must be revoked (requires `import os` at top of file)
|
| 30 |
)
|
| 31 |
self.vector_store = None
|
| 32 |
-
|
| 33 |
# 2. Initialize Memory Checkpointer
|
| 34 |
self.memory = MemorySaver()
|
| 35 |
self.workflow = self._build_graph()
|
| 36 |
|
| 37 |
-
def process_documents(self,
|
| 38 |
-
|
| 39 |
-
Expects a list of tuples: [(temp_path, original_name), ...]
|
| 40 |
-
"""
|
| 41 |
all_docs = []
|
| 42 |
-
for
|
| 43 |
-
loader = PyPDFLoader(
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
# Override the metadata source with the original filename
|
| 47 |
-
for doc in docs:
|
| 48 |
-
doc.metadata["source"] = original_name
|
| 49 |
-
|
| 50 |
-
all_docs.extend(docs)
|
| 51 |
-
|
| 52 |
-
splits = RecursiveCharacterTextSplitter(
|
| 53 |
-
chunk_size=500,
|
| 54 |
-
chunk_overlap=100
|
| 55 |
-
).split_documents(all_docs)
|
| 56 |
|
|
|
|
| 57 |
self.vector_store = FAISS.from_documents(splits, self.embeddings)
|
| 58 |
-
|
| 59 |
-
# def process_documents(self, pdf_paths):
|
| 60 |
-
# all_docs = []
|
| 61 |
-
# for path in pdf_paths:
|
| 62 |
-
# loader = PyPDFLoader(path)
|
| 63 |
-
# all_docs.extend(loader.load())
|
| 64 |
-
|
| 65 |
-
# splits = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100).split_documents(all_docs)
|
| 66 |
-
# self.vector_store = FAISS.from_documents(splits, self.embeddings)
|
| 67 |
|
| 68 |
# --- GRAPH NODES ---
|
| 69 |
def retrieve(self, state: GraphState):
|
| 70 |
print("--- RETRIEVING ---")
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
| 72 |
documents = retriever.invoke(state["question"])
|
| 73 |
return {"context": documents}
|
| 74 |
|
|
|
|
| 29 |
api_key="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" # Keep your API keys safe!
|
| 30 |
)
|
| 31 |
self.vector_store = None
|
| 32 |
+
self.pdf_count = 0
|
| 33 |
# 2. Initialize Memory Checkpointer
|
| 34 |
self.memory = MemorySaver()
|
| 35 |
self.workflow = self._build_graph()
|
| 36 |
|
| 37 |
+
def process_documents(self, pdf_paths):
    """Load PDF files, split them into chunks, and build the FAISS vector store.

    Args:
        pdf_paths: Iterable of filesystem paths to PDF files.

    Raises:
        ValueError: If ``pdf_paths`` is empty — ``FAISS.from_documents``
            fails opaquely on an empty corpus, so fail fast instead.

    Side effects:
        Sets ``self.pdf_count`` (``retrieve`` uses it to size its dynamic
        ``k``) and rebuilds ``self.vector_store`` from scratch.
    """
    # Materialize once so a generator argument is counted and iterated safely.
    pdf_paths = list(pdf_paths)
    self.pdf_count = len(pdf_paths)  # Track how many PDFs were uploaded
    if not pdf_paths:
        raise ValueError("process_documents() requires at least one PDF path")

    all_docs = []
    for path in pdf_paths:
        loader = PyPDFLoader(path)
        all_docs.extend(loader.load())

    splits = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=100,
    ).split_documents(all_docs)
    self.vector_store = FAISS.from_documents(splits, self.embeddings)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
# --- GRAPH NODES ---
|
| 48 |
def retrieve(self, state: GraphState):
    """Graph node: fetch context documents relevant to ``state["question"]``.

    Uses MMR retrieval with ``k`` scaled to the number of uploaded PDFs so
    every document has a chance to contribute at least one chunk.

    Returns:
        dict: ``{"context": documents}`` for the next graph node.

    Raises:
        RuntimeError: If no documents have been indexed yet.
    """
    print("--- RETRIEVING ---")
    if self.vector_store is None:
        # process_documents() has not run; fail with a clear message
        # instead of an AttributeError on None.
        raise RuntimeError("No documents indexed — call process_documents() first")

    # Dynamic k: roughly one chunk per PDF plus two extra, never below 1
    # (the floor guards against pdf_count somehow going negative).
    k_value = max(1, self.pdf_count + 2)
    retriever = self.vector_store.as_retriever(
        search_type="mmr",
        search_kwargs={"k": k_value, "lambda_mult": 0.25},
    )
    documents = retriever.invoke(state["question"])
    return {"context": documents}
|
| 56 |
|