Dinesh310 commited on
Commit
7f9f761
·
verified ·
1 Parent(s): cfb3c7f

Create RAG_builder.py

Browse files
Files changed (1) hide show
  1. src/RAG_builder.py +93 -0
src/RAG_builder.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, TypedDict
3
+ from langgraph.graph import StateGraph, END
4
+ # 1. Import MemorySaver for persistence
5
+ from langgraph.checkpoint.memory import MemorySaver
6
+ from langchain_community.document_loaders import PyPDFLoader
7
+ from langchain_core.documents import Document
8
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
9
+ from langchain_community.vectorstores import FAISS
10
+ from langchain_core.prompts import ChatPromptTemplate
11
+ from langchain_openai import ChatOpenAI
12
+ from langchain_community.embeddings import HuggingFaceEmbeddings
13
+
14
class GraphState(TypedDict):
    """State dictionary passed between nodes of the RAG graph.

    Keys:
        question: the user's query string (set by the caller in `query`).
        context: documents retrieved from the vector store (set by `retrieve`).
        answer: the LLM-generated response text (set by `generate`).
    """
    question: str  # user's query
    context: List[Document]  # retrieved documents feeding the prompt
    answer: str  # final generated answer
18
+
19
class ProjectRAGGraph:
    """LangGraph-based RAG pipeline over PDF documents.

    Builds a two-node graph (retrieve -> generate) compiled with an
    in-memory checkpointer (MemorySaver), so each conversation thread's
    state can be resumed via a ``thread_id``.

    Typical usage:
        rag = ProjectRAGGraph()
        rag.process_documents(["report.pdf"])
        answer = rag.query("What is the budget?", thread_id="user-1")
    """

    def __init__(self):
        # CPU-only sentence-transformer embeddings; normalization makes
        # inner-product search behave like cosine similarity in FAISS.
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2",
            model_kwargs={"device": "cpu"},
            encode_kwargs={"normalize_embeddings": True},
        )
        # FIX: never hard-code secrets in source control. Read the key from
        # the environment; the placeholder fallback keeps the constructor
        # backward compatible (requests still fail at call time, as before,
        # if the variable is unset).
        self.llm = ChatOpenAI(
            model="openai/gpt-oss-120b:free",
            base_url="https://openrouter.ai/api/v1",
            api_key=os.environ.get("OPENROUTER_API_KEY", "your-api-key"),
        )
        # Populated by process_documents(); retrieve() refuses to run
        # until this is set.
        self.vector_store = None

        # Checkpointer enables per-thread conversation persistence.
        self.memory = MemorySaver()
        self.workflow = self._build_graph()

    def process_documents(self, pdf_paths):
        """Load, chunk, and index the given PDF files into FAISS.

        Args:
            pdf_paths: iterable of filesystem paths to PDF files.

        Raises:
            ValueError: if `pdf_paths` is empty or yields no pages
                (FAISS cannot build an index from zero documents).
        """
        all_docs = []
        for path in pdf_paths:
            loader = PyPDFLoader(path)
            all_docs.extend(loader.load())

        if not all_docs:
            raise ValueError("No documents loaded; provide at least one readable PDF path.")

        splits = RecursiveCharacterTextSplitter(
            chunk_size=500, chunk_overlap=100
        ).split_documents(all_docs)
        self.vector_store = FAISS.from_documents(splits, self.embeddings)

    # --- GRAPH NODES ---
    def retrieve(self, state: GraphState):
        """Graph node: fetch relevant chunks for the question via MMR search.

        Returns a partial state update: {"context": [Document, ...]}.

        Raises:
            RuntimeError: if process_documents() has not been called yet.
        """
        print("--- RETRIEVING ---")
        # FIX: fail with a clear message instead of an AttributeError on None.
        if self.vector_store is None:
            raise RuntimeError("Vector store not initialized; call process_documents() first.")
        # MMR balances relevance with diversity; low lambda_mult favors diversity.
        retriever = self.vector_store.as_retriever(
            search_type="mmr",
            search_kwargs={"k": 5, "lambda_mult": 0.25},
        )
        documents = retriever.invoke(state["question"])
        return {"context": documents}

    def generate(self, state: GraphState):
        """Graph node: answer the question from the retrieved context.

        Concatenates the retrieved documents' text into the prompt and
        returns a partial state update: {"answer": str}.
        """
        print("--- GENERATING ---")
        prompt = ChatPromptTemplate.from_template("""
        You are a professional Project Analyst.
        Context: {context}
        Question: {question}
        Answer strictly using the context. Cite sources.
        """)

        formatted_context = "\n\n".join(d.page_content for d in state["context"])
        chain = prompt | self.llm
        response = chain.invoke({
            "context": formatted_context,
            "question": state["question"],
        })

        return {"answer": response.content}

    # --- GRAPH CONSTRUCTION ---
    def _build_graph(self):
        """Wire the linear retrieve -> generate graph and compile it.

        Compiling with the MemorySaver checkpointer is what enables
        per-thread persistence in query().
        """
        workflow = StateGraph(GraphState)

        workflow.add_node("retrieve", self.retrieve)
        workflow.add_node("generate", self.generate)

        workflow.set_entry_point("retrieve")
        workflow.add_edge("retrieve", "generate")
        workflow.add_edge("generate", END)

        return workflow.compile(checkpointer=self.memory)

    def query(self, question: str, thread_id: str):
        """Execute the graph for `question` under the given thread ID.

        The thread_id keys the checkpointer's saved state, so repeated
        calls with the same ID resume the same conversation thread.

        Returns:
            The generated answer string.
        """
        config = {"configurable": {"thread_id": thread_id}}
        inputs = {"question": question}

        # The checkpointer looks up / stores state under this thread_id.
        result = self.workflow.invoke(inputs, config=config)
        return result["answer"]