# Demo_1/src/RAG_builder.py — uploaded by Dinesh310
# Last commit: "Update src/RAG_builder.py" (3305f51, verified)
import os
from typing import List, TypedDict
from langgraph.graph import StateGraph, END
# 1. Import MemorySaver for persistence
from langgraph.checkpoint.memory import MemorySaver
from langchain_community.document_loaders import PyPDFLoader
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_huggingface import HuggingFaceEmbeddings
class GraphState(TypedDict):
    """Shared state passed between the nodes of the RAG graph."""

    # The user's question; read by both the retrieve and generate nodes.
    question: str
    # Document chunks fetched by the retrieve node for this question.
    context: List[Document]
    # Final answer text produced by the generate node.
    answer: str
class ProjectRAGGraph:
    """LangGraph-based RAG pipeline over uploaded PDF project reports.

    Indexes PDFs into a FAISS vector store, then answers questions via a
    two-node graph (retrieve -> generate) whose state is checkpointed
    per conversation thread with an in-memory ``MemorySaver``.
    """

    def __init__(self):
        # Embeddings run locally on CPU; vectors are normalized so FAISS
        # similarity behaves like cosine similarity.
        self.embeddings = HuggingFaceEmbeddings(
            model_name="google/embeddinggemma-300m",
            model_kwargs={"device": "cpu"},
            encode_kwargs={"normalize_embeddings": True},
        )
        # SECURITY FIX: the OpenRouter API key was hard-coded in source.
        # Read it from the environment instead so secrets are never
        # committed; set OPENROUTER_API_KEY before constructing this class.
        self.llm = ChatOpenAI(
            model="openai/gpt-oss-120b:free",
            base_url="https://openrouter.ai/api/v1",
            api_key=os.environ.get("OPENROUTER_API_KEY", ""),
        )
        self.vector_store = None  # FAISS index; set by process_documents()
        self.pdf_count = 0  # number of PDFs indexed; drives the dynamic k
        # In-memory checkpointer: each thread_id gets its own state history.
        self.memory = MemorySaver()
        self.workflow = self._build_graph()

    def process_documents(self, pdf_paths, original_names=None):
        """Load, split, and index the given PDFs into a fresh FAISS store.

        Args:
            pdf_paths: Filesystem paths of the PDFs to index.
            original_names: Optional user-facing names (e.g. upload
                filenames). When provided, each loaded document's
                ``source`` metadata is overwritten so citations show the
                original name rather than a temp-file path.
        """
        self.pdf_count = len(pdf_paths)
        all_docs = []
        for i, path in enumerate(pdf_paths):
            docs = PyPDFLoader(path).load()
            # Replace the loader-recorded path with the display name, when
            # one was supplied for this position.
            if original_names and i < len(original_names):
                for doc in docs:
                    doc.metadata["source"] = original_names[i]
            all_docs.extend(docs)
        # Split AFTER metadata has been corrected so every chunk inherits
        # the fixed 'source' value.
        splits = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=100,
        ).split_documents(all_docs)
        self.vector_store = FAISS.from_documents(splits, self.embeddings)

    # --- GRAPH NODES ---
    def retrieve(self, state: GraphState):
        """Graph node: fetch context chunks relevant to the question."""
        print("--- RETRIEVING ---")
        if self.vector_store is None:
            # Fail fast with a clear message instead of an AttributeError.
            raise RuntimeError(
                "No documents indexed; call process_documents() first."
            )
        # Scale k with the number of uploaded PDFs (+2 headroom, min 1) so
        # multi-document questions can see chunks from every report.
        k_value = max(1, self.pdf_count + 2)
        retriever = self.vector_store.as_retriever(
            search_type="mmr",
            search_kwargs={"k": k_value, "lambda_mult": 0.25},
        )
        documents = retriever.invoke(state["question"])
        return {"context": documents}

    def generate(self, state: GraphState):
        """Graph node: answer the question using only the retrieved context."""
        print("--- GENERATING ---")
        prompt = ChatPromptTemplate.from_template("""
You are an expert Project Analyst.
Answer ONLY using the provided context from multiple project reports.
If the answer is not explicitly present, respond with "I don't know."
When comparing projects, clearly separate insights per project.
Context:
{context}
Question:
{question}
""")
        # Flatten retrieved chunks into one prompt-ready text blob.
        formatted_context = "\n\n".join(d.page_content for d in state["context"])
        chain = prompt | self.llm
        response = chain.invoke({
            "context": formatted_context,
            "question": state["question"],
        })
        return {"answer": response.content}

    # --- GRAPH CONSTRUCTION ---
    def _build_graph(self):
        """Compile the retrieve -> generate graph with the memory checkpointer."""
        workflow = StateGraph(GraphState)
        workflow.add_node("retrieve", self.retrieve)
        workflow.add_node("generate", self.generate)
        workflow.set_entry_point("retrieve")
        workflow.add_edge("retrieve", "generate")
        workflow.add_edge("generate", END)
        # The checkpointer persists graph state per thread_id across
        # successive query() calls.
        return workflow.compile(checkpointer=self.memory)

    def query(self, question: str, thread_id: str):
        """Execute the graph for ``question`` under a conversation thread.

        Args:
            question: The user's question.
            thread_id: Key for the MemorySaver checkpoint; calls sharing a
                thread_id share persisted graph state.

        Returns:
            The generated answer string.
        """
        config = {"configurable": {"thread_id": thread_id}}
        result = self.workflow.invoke({"question": question}, config=config)
        return result["answer"]