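"""LangGraph RAG pipeline over multiple project PDF reports.

Retrieves with FAISS + MMR, generates with an OpenRouter-hosted LLM, and
persists per-conversation state through a MemorySaver checkpointer keyed
by thread_id.
"""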
import os
from typing import List, Optional, TypedDict
from langgraph.graph import StateGraph, END
# 1. Import MemorySaver for persistence
from langgraph.checkpoint.memory import MemorySaver
from langchain_community.document_loaders import PyPDFLoader
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_huggingface import HuggingFaceEmbeddings


class GraphState(TypedDict):
    """State shared across the graph nodes."""

    question: str
    context: List[Document]
    answer: str


class ProjectRAGGraph:
    def __init__(self):
        self.embeddings = HuggingFaceEmbeddings(
            model_name="google/embeddinggemma-300m",
            model_kwargs={"device": "cpu"},
            encode_kwargs={"normalize_embeddings": True},
        )
        self.llm = ChatOpenAI(
            model="openai/gpt-oss-120b:free",
            base_url="https://openrouter.ai/api/v1",
            # Keep your API keys safe: read the key from the environment
            # (e.g. export OPENROUTER_API_KEY=...) instead of hard-coding it.
            api_key=os.environ["OPENROUTER_API_KEY"],
        )
        self.vector_store = None
        self.pdf_count = 0
        # 2. Initialize Memory Checkpointer
        self.memory = MemorySaver()
        self.workflow = self._build_graph()

    def process_documents(self, pdf_paths: List[str], original_names: Optional[List[str]] = None):
        self.pdf_count = len(pdf_paths)
        all_docs = []
        # Iterate through paths and original names simultaneously
        for i, path in enumerate(pdf_paths):
            loader = PyPDFLoader(path)
            docs = loader.load()
            # If original names are provided, overwrite the 'source' metadata
            # (uploaded files typically arrive under temp-file paths)
            if original_names and i < len(original_names):
                for doc in docs:
                    doc.metadata["source"] = original_names[i]
            all_docs.extend(docs)
        # Split documents after metadata has been corrected
        splits = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=100,
        ).split_documents(all_docs)
        self.vector_store = FAISS.from_documents(splits, self.embeddings)

    # --- GRAPH NODES ---
    def retrieve(self, state: GraphState):
        print("--- RETRIEVING ---")
        if self.vector_store is None:
            raise RuntimeError("No documents indexed yet; call process_documents() first.")
        # Scale k with the number of uploaded PDFs so every report can contribute chunks
        k_value = max(1, self.pdf_count + 2)
        retriever = self.vector_store.as_retriever(
            search_type="mmr",
            search_kwargs={"k": k_value, "lambda_mult": 0.25},
        )
        documents = retriever.invoke(state["question"])
        return {"context": documents}

    def generate(self, state: GraphState):
        print("--- GENERATING ---")
        prompt = ChatPromptTemplate.from_template("""
You are an expert Project Analyst.
Answer ONLY using the provided context from multiple project reports.
If the answer is not explicitly present, respond with "I don't know."
When comparing projects, clearly separate insights per project.

Context:
{context}

Question:
{question}
""")
        formatted_context = "\n\n".join(d.page_content for d in state["context"])
        chain = prompt | self.llm
        response = chain.invoke({
            "context": formatted_context,
            "question": state["question"],
        })
        return {"answer": response.content}

    # --- GRAPH CONSTRUCTION ---
    def _build_graph(self):
        workflow = StateGraph(GraphState)
        workflow.add_node("retrieve", self.retrieve)
        workflow.add_node("generate", self.generate)
        workflow.set_entry_point("retrieve")
        workflow.add_edge("retrieve", "generate")
        workflow.add_edge("generate", END)
        # 3. Compile the graph with the checkpointer
        return workflow.compile(checkpointer=self.memory)

    def query(self, question: str, thread_id: str):
        """Executes the graph with a specific thread ID for persistence."""
        # 4. Pass the thread_id in the config
        config = {"configurable": {"thread_id": thread_id}}
        inputs = {"question": question}
        # The graph now knows to look up the checkpointed state for this thread_id
        result = self.workflow.invoke(inputs, config=config)
        return result["answer"]
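

if __name__ == "__main__":
    # Minimal usage sketch. The file paths, display names, and questions below
    # are assumptions for illustration; adjust them to your setup. Requires
    # OPENROUTER_API_KEY in the environment and the PDFs on disk.
    rag = ProjectRAGGraph()
    rag.process_documents(
        ["/tmp/upload_ab12.pdf", "/tmp/upload_cd34.pdf"],        # temp paths from an upload widget
        original_names=["alpha_report.pdf", "beta_report.pdf"],  # restored 'source' metadata
    )
    # Reusing the same thread_id makes the MemorySaver checkpointer resume that
    # thread's state; a different thread_id starts a fresh, isolated session.
    print(rag.query("Compare the budgets of both projects.", thread_id="session-1"))
    print(rag.query("Which project finished first?", thread_id="session-1"))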