import os
from typing import List, TypedDict

from langgraph.graph import StateGraph, END
# 1. Import MemorySaver for persistence
from langgraph.checkpoint.memory import MemorySaver
from langchain_community.document_loaders import PyPDFLoader
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_huggingface import HuggingFaceEmbeddings


class GraphState(TypedDict):
    question: str
    context: List[Document]
    answer: str


class ProjectRAGGraph:
    def __init__(self):
        self.embeddings = HuggingFaceEmbeddings(
            model_name="google/embeddinggemma-300m",
            model_kwargs={"device": "cpu"},
            encode_kwargs={"normalize_embeddings": True},
        )
        self.llm = ChatOpenAI(
            model="openai/gpt-oss-120b:free",
            base_url="https://openrouter.ai/api/v1",
            # Keep your API keys safe! Load the key from the environment
            # (OPENROUTER_API_KEY is an assumed variable name) instead of hardcoding it.
            api_key=os.environ["OPENROUTER_API_KEY"],
        )
        self.vector_store = None
        self.pdf_count = 0
        # 2. Initialize the memory checkpointer
        self.memory = MemorySaver()
        self.workflow = self._build_graph()

    def process_documents(self, pdf_paths, original_names=None):
        self.pdf_count = len(pdf_paths)  # Track how many PDFs were uploaded
        all_docs = []
        # Iterate through paths and original names simultaneously
        for i, path in enumerate(pdf_paths):
            loader = PyPDFLoader(path)
            docs = loader.load()
            # If original names are provided, overwrite the 'source' metadata
            if original_names and i < len(original_names):
                for doc in docs:
                    doc.metadata["source"] = original_names[i]
            all_docs.extend(docs)
        # Split documents only after the metadata has been corrected
        splits = RecursiveCharacterTextSplitter(
            chunk_size=500, chunk_overlap=100
        ).split_documents(all_docs)
        self.vector_store = FAISS.from_documents(splits, self.embeddings)

    # --- GRAPH NODES ---
    def retrieve(self, state: GraphState):
        print("--- RETRIEVING ---")
        # Calculate a dynamic k: scale with the number of uploaded PDFs,
        # never dropping below 1
        k_value = max(1, self.pdf_count + 2)
        retriever = self.vector_store.as_retriever(
            search_type="mmr",
            search_kwargs={"k": k_value, "lambda_mult": 0.25},
        )
        documents = retriever.invoke(state["question"])
        return {"context": documents}

    def generate(self, state: GraphState):
        print("--- GENERATING ---")
        prompt = ChatPromptTemplate.from_template("""
You are an expert Project Analyst.
Answer ONLY using the provided context from multiple project reports.
If the answer is not explicitly present, respond with "I don't know."
When comparing projects, clearly separate insights per project.

Context:
{context}

Question:
{question}
""")
        formatted_context = "\n\n".join(d.page_content for d in state["context"])
        chain = prompt | self.llm
        response = chain.invoke({
            "context": formatted_context,
            "question": state["question"],
        })
        return {"answer": response.content}

    # --- GRAPH CONSTRUCTION ---
    def _build_graph(self):
        workflow = StateGraph(GraphState)
        workflow.add_node("retrieve", self.retrieve)
        workflow.add_node("generate", self.generate)
        workflow.set_entry_point("retrieve")
        workflow.add_edge("retrieve", "generate")
        workflow.add_edge("generate", END)
        # 3. Compile the graph with the checkpointer
        return workflow.compile(checkpointer=self.memory)

    def query(self, question: str, thread_id: str):
        """Executes the graph with a specific thread ID for persistence."""
        # 4. Pass the thread_id in the config
        config = {"configurable": {"thread_id": thread_id}}
        inputs = {"question": question}
        # The graph now knows to look up the state for this thread_id
        result = self.workflow.invoke(inputs, config=config)
        return result["answer"]
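

# A minimal usage sketch, not part of the class above. It assumes
# OPENROUTER_API_KEY is set in the environment and that the PDF paths and
# thread_id below are hypothetical placeholders.
if __name__ == "__main__":
    rag = ProjectRAGGraph()
    # Index two uploaded reports, preserving their original filenames
    # in the 'source' metadata
    rag.process_documents(
        ["/tmp/upload_1.pdf", "/tmp/upload_2.pdf"],
        original_names=["Project Alpha Report.pdf", "Project Beta Report.pdf"],
    )
    # Reusing the same thread_id lets the checkpointer resume that
    # conversation's state across calls
    print(rag.query("Compare the budgets of both projects.", thread_id="session-1"))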