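"""LangGraph-based RAG pipeline over multiple PDF project reports.

Wires a two-node retrieve -> generate graph with FAISS vector search and
HuggingFace embeddings, and attaches a MemorySaver checkpointer so each
run's state is persisted under a caller-supplied thread_id.
"""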
import os
from typing import List, TypedDict
from langgraph.graph import StateGraph, END
# 1. Import MemorySaver for persistence
from langgraph.checkpoint.memory import MemorySaver 
from langchain_community.document_loaders import PyPDFLoader
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_huggingface import HuggingFaceEmbeddings

class GraphState(TypedDict):
    question: str
    context: List[Document]
    answer: str

class ProjectRAGGraph:
    def __init__(self):
        self.embeddings = HuggingFaceEmbeddings(
            model_name="google/embeddinggemma-300m",
            model_kwargs={"device": "cpu"},
            encode_kwargs={"normalize_embeddings": True}
        )
        self.llm = ChatOpenAI(
            model="openai/gpt-oss-120b:free",
            base_url="https://openrouter.ai/api/v1",
            # Never hardcode real API keys in source; read the key from the
            # environment instead (OPENROUTER_API_KEY is assumed here).
            api_key=os.environ.get("OPENROUTER_API_KEY")
        )
        self.vector_store = None
        self.pdf_count = 0
        # 2. Initialize Memory Checkpointer
        self.memory = MemorySaver()
        self.workflow = self._build_graph()

    def process_documents(self, pdf_paths, original_names=None):
        self.pdf_count = len(pdf_paths)
        all_docs = []
        
        # Iterate through paths and original names simultaneously
        for i, path in enumerate(pdf_paths):
            loader = PyPDFLoader(path)
            docs = loader.load()
            
            # If original names are provided, overwrite the 'source' metadata
            if original_names and i < len(original_names):
                for doc in docs:
                    doc.metadata["source"] = original_names[i]
            
            all_docs.extend(docs)
                
        # Split documents after metadata has been corrected
        splits = RecursiveCharacterTextSplitter(
            chunk_size=500, 
            chunk_overlap=100
        ).split_documents(all_docs)
        
        self.vector_store = FAISS.from_documents(splits, self.embeddings)

    # --- GRAPH NODES ---
    def retrieve(self, state: GraphState):
        print("--- RETRIEVING ---")
        # Scale k with the number of uploaded PDFs so every report can
        # contribute at least one chunk; floor at 1 when no PDFs are tracked.
        k_value = max(1, self.pdf_count + 2)
        retriever = self.vector_store.as_retriever(
            search_type="mmr",
            search_kwargs={"k": k_value, "lambda_mult": 0.25},
        )
        documents = retriever.invoke(state["question"])
        return {"context": documents}

    def generate(self, state: GraphState):
        print("--- GENERATING ---")
        # Left-align the template so the indentation does not leak into the prompt.
        prompt = ChatPromptTemplate.from_template("""You are an expert Project Analyst.
Answer ONLY using the provided context from multiple project reports.
If the answer is not explicitly present, respond with "I don't know."
When comparing projects, clearly separate insights per project.

Context:
{context}

Question:
{question}""")
        
        formatted_context = "\n\n".join(d.page_content for d in state["context"])
        chain = prompt | self.llm
        response = chain.invoke({
            "context": formatted_context, 
            "question": state["question"]
        })
        
        return {"answer": response.content}

    # --- GRAPH CONSTRUCTION ---
    def _build_graph(self):
        workflow = StateGraph(GraphState)

        workflow.add_node("retrieve", self.retrieve)
        workflow.add_node("generate", self.generate)

        workflow.set_entry_point("retrieve")
        workflow.add_edge("retrieve", "generate")
        workflow.add_edge("generate", END)

        # 3. Compile the graph with the checkpointer
        return workflow.compile(checkpointer=self.memory)

    def query(self, question: str, thread_id: str):
        """Executes the graph with a specific thread ID for persistence."""
        # 4. Pass the thread_id in the config
        config = {"configurable": {"thread_id": thread_id}}
        inputs = {"question": question}
        
        # The graph now knows to look up the state for this thread_id
        result = self.workflow.invoke(inputs, config=config)
        return result["answer"]
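
# --- EXAMPLE USAGE ---
# A minimal sketch of how the class is driven end to end. The PDF paths,
# display names, and thread ID below are hypothetical placeholders, and the
# OPENROUTER_API_KEY environment variable is assumed to be set.
if __name__ == "__main__":
    rag = ProjectRAGGraph()
    rag.process_documents(
        ["/tmp/upload_1.pdf", "/tmp/upload_2.pdf"],  # hypothetical temp upload paths
        original_names=["Project_Alpha.pdf", "Project_Beta.pdf"],
    )
    # Reusing the same thread_id lets the MemorySaver checkpointer persist
    # graph state across invocations; a new thread_id starts a clean slate.
    print(rag.query("Compare the objectives of both projects.", thread_id="session-1"))
    print(rag.query("Which project reports the larger budget?", thread_id="session-1"))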