Dinesh310 committed on
Commit
58316f5
·
verified ·
1 Parent(s): 8ce6ed1

Update src/RAG_builder.py

Browse files
Files changed (1) hide show
  1. src/RAG_builder.py +11 -29
src/RAG_builder.py CHANGED
@@ -29,46 +29,28 @@ class ProjectRAGGraph:
29
  api_key="<REDACTED — this OpenRouter key was leaked in a public commit; revoke it and load the key from an environment variable>" # Keep your API keys safe!
30
  )
31
  self.vector_store = None
32
-
33
  # 2. Initialize Memory Checkpointer
34
  self.memory = MemorySaver()
35
  self.workflow = self._build_graph()
36
 
37
- def process_documents(self, pdf_paths_with_names: list[tuple[str, str]]):
38
- """
39
- Expects a list of tuples: [(temp_path, original_name), ...]
40
- """
41
  all_docs = []
42
- for temp_path, original_name in pdf_paths_with_names:
43
- loader = PyPDFLoader(temp_path)
44
- docs = loader.load()
45
-
46
- # Override the metadata source with the original filename
47
- for doc in docs:
48
- doc.metadata["source"] = original_name
49
-
50
- all_docs.extend(docs)
51
-
52
- splits = RecursiveCharacterTextSplitter(
53
- chunk_size=500,
54
- chunk_overlap=100
55
- ).split_documents(all_docs)
56
 
 
57
  self.vector_store = FAISS.from_documents(splits, self.embeddings)
58
-
59
- # def process_documents(self, pdf_paths):
60
- # all_docs = []
61
- # for path in pdf_paths:
62
- # loader = PyPDFLoader(path)
63
- # all_docs.extend(loader.load())
64
-
65
- # splits = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100).split_documents(all_docs)
66
- # self.vector_store = FAISS.from_documents(splits, self.embeddings)
67
 
68
  # --- GRAPH NODES ---
69
  def retrieve(self, state: GraphState):
70
  print("--- RETRIEVING ---")
71
- retriever = self.vector_store.as_retriever(search_type="mmr", search_kwargs={"k": 5, "lambda_mult":0.25})
 
 
 
72
  documents = retriever.invoke(state["question"])
73
  return {"context": documents}
74
 
 
29
  api_key="<REDACTED — this OpenRouter key was leaked in a public commit; revoke it and load the key from an environment variable>" # Keep your API keys safe!
30
  )
31
  self.vector_store = None
32
+ self.pdf_count = 0
33
  # 2. Initialize Memory Checkpointer
34
  self.memory = MemorySaver()
35
  self.workflow = self._build_graph()
36
 
37
+ def process_documents(self, pdf_paths):
38
+ self.pdf_count = len(pdf_paths) # Track how many PDFs were uploaded
 
 
39
  all_docs = []
40
+ for path in pdf_paths:
41
+ loader = PyPDFLoader(path)
42
+ all_docs.extend(loader.load())
 
 
 
 
 
 
 
 
 
 
 
43
 
44
+ splits = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100).split_documents(all_docs)
45
  self.vector_store = FAISS.from_documents(splits, self.embeddings)
 
 
 
 
 
 
 
 
 
46
 
47
  # --- GRAPH NODES ---
48
  def retrieve(self, state: GraphState):
49
  print("--- RETRIEVING ---")
50
+ # Calculate dynamic k
51
+ dynamic_k = self.pdf_count + 2
52
+ k_value = max(1, dynamic_k)
53
+ retriever = self.vector_store.as_retriever(search_type="mmr", search_kwargs={"k": k_value, "lambda_mult":0.25})
54
  documents = retriever.invoke(state["question"])
55
  return {"context": documents}
56