|
|
import os |
|
|
from typing import List, TypedDict |
|
|
from langgraph.graph import StateGraph, END |
|
|
|
|
|
from langgraph.checkpoint.memory import MemorySaver |
|
|
from langchain_community.document_loaders import PyPDFLoader |
|
|
from langchain_core.documents import Document |
|
|
from langchain_text_splitters import RecursiveCharacterTextSplitter |
|
|
from langchain_community.vectorstores import FAISS |
|
|
from langchain_core.prompts import ChatPromptTemplate |
|
|
from langchain_openai import ChatOpenAI |
|
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
|
|
|
|
class GraphState(TypedDict):
    """State dictionary passed between the nodes of the RAG workflow graph."""

    # The user's question; supplied as the graph's input.
    question: str
    # Document chunks returned by the retrieval step for the question.
    context: list[Document]
    # Text produced by the generation step.
    answer: str
|
|
|
|
|
class ProjectRAGGraph:
    """Retrieval-augmented question answering over a set of project-report PDFs.

    Call :meth:`process_documents` to build a FAISS index from PDFs, then
    :meth:`query` to run a two-node LangGraph workflow (``retrieve`` ->
    ``generate``) whose runs are checkpointed in memory per ``thread_id``.
    """

    # Environment variable consulted when no ``api_key`` is passed explicitly.
    API_KEY_ENV_VAR = "OPENROUTER_API_KEY"

    # Prompt built once at class creation instead of on every generate() call.
    _PROMPT = ChatPromptTemplate.from_template(
        "You are an expert Project Analyst.\n"
        "Answer ONLY using the provided context from multiple project reports.\n"
        'If the answer is not explicitly present, respond with "I don\'t know."\n'
        "When comparing projects, clearly separate insights per project.\n"
        "Context:\n"
        "{context}\n"
        "\n"
        "Question:\n"
        "{question}\n"
    )

    def __init__(self, *, api_key=None):
        """Create the embedding model, the chat model and the compiled graph.

        Args:
            api_key: OpenRouter API key. When omitted (or empty), the value of
                the ``OPENROUTER_API_KEY`` environment variable is used.

        Raises:
            RuntimeError: if no API key is passed and the environment
                variable is unset or empty.
        """
        # NOTE(review): an API key used to be hard-coded here and was committed
        # to source control -- treat that key as leaked and rotate it.
        resolved_key = api_key or os.environ.get(self.API_KEY_ENV_VAR)
        if not resolved_key:
            # Checked first so we fail before loading the embedding model.
            raise RuntimeError(
                f"No OpenRouter API key: pass api_key=... or set the "
                f"{self.API_KEY_ENV_VAR} environment variable."
            )

        self.embeddings = HuggingFaceEmbeddings(
            model_name="google/embeddinggemma-300m",
            model_kwargs={"device": "cpu"},
            encode_kwargs={"normalize_embeddings": True},
        )

        self.llm = ChatOpenAI(
            model="openai/gpt-oss-120b:free",
            base_url="https://openrouter.ai/api/v1",
            api_key=resolved_key,
        )

        # Populated by process_documents(); None means nothing is indexed yet.
        self.vector_store = None
        # Number of PDFs in the current index; used to size retrieval.
        self.pdf_count = 0

        self.memory = MemorySaver()
        self.workflow = self._build_graph()

    def process_documents(self, pdf_paths, original_names=None):
        """Load, chunk and index PDFs, replacing any previously built index.

        Args:
            pdf_paths: paths of the PDF files to load.
            original_names: optional display names aligned by position with
                ``pdf_paths``. When a name exists for a file, it overwrites the
                ``source`` metadata of every page loaded from that file.

        Raises:
            ValueError: if ``pdf_paths`` is empty or no text chunks were
                produced from the given files.
        """
        if not pdf_paths:
            raise ValueError("process_documents() needs at least one PDF path.")

        all_docs = []
        for i, path in enumerate(pdf_paths):
            docs = PyPDFLoader(path).load()
            if original_names and i < len(original_names):
                for doc in docs:
                    doc.metadata["source"] = original_names[i]
            all_docs.extend(docs)

        splits = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=100,
        ).split_documents(all_docs)
        if not splits:
            raise ValueError("No text chunks were produced from the given PDFs.")

        # Commit state only after indexing succeeds, so a failed call leaves
        # vector_store and pdf_count consistent with each other.
        self.vector_store = FAISS.from_documents(splits, self.embeddings)
        self.pdf_count = len(pdf_paths)

    def _require_vector_store(self):
        """Return the FAISS index, or raise if process_documents() has not run."""
        if self.vector_store is None:
            raise RuntimeError(
                "No documents indexed yet; call process_documents() first."
            )
        return self.vector_store

    def retrieve(self, state: GraphState):
        """Graph node: fetch context chunks for ``state["question"]`` via MMR search."""
        print("--- RETRIEVING ---")
        vector_store = self._require_vector_store()

        # k grows with the number of indexed PDFs (presumably so that several
        # reports can be represented in the context).
        k_value = max(1, self.pdf_count + 2)
        retriever = vector_store.as_retriever(
            search_type="mmr",
            search_kwargs={
                "k": k_value,
                # MMR selects k results out of fetch_k candidates; keep the
                # candidate pool at least as large as k (20 is the default).
                "fetch_k": max(20, k_value),
                # Low lambda_mult favours diversity over raw similarity.
                "lambda_mult": 0.25,
            },
        )
        return {"context": retriever.invoke(state["question"])}

    def generate(self, state: GraphState):
        """Graph node: answer ``state["question"]`` from ``state["context"]`` with the LLM."""
        print("--- GENERATING ---")

        # Tag each chunk with its source label (set in process_documents) so
        # the model can attribute facts to a specific project, as the prompt
        # instructs.
        formatted_context = "\n\n".join(
            f"[Source: {d.metadata.get('source', 'unknown')}]\n{d.page_content}"
            for d in state["context"]
        )

        chain = self._PROMPT | self.llm
        response = chain.invoke({
            "context": formatted_context,
            "question": state["question"],
        })
        return {"answer": response.content}

    def _build_graph(self):
        """Compile the linear retrieve -> generate graph with the in-memory checkpointer."""
        workflow = StateGraph(GraphState)

        workflow.add_node("retrieve", self.retrieve)
        workflow.add_node("generate", self.generate)

        workflow.set_entry_point("retrieve")
        workflow.add_edge("retrieve", "generate")
        workflow.add_edge("generate", END)

        return workflow.compile(checkpointer=self.memory)

    def query(self, question: str, thread_id: str):
        """Executes the graph with a specific thread ID for persistence.

        Args:
            question: the user's question.
            thread_id: checkpoint thread key passed to the MemorySaver via
                ``config["configurable"]["thread_id"]``.

        Returns:
            The ``answer`` produced by the generate node.

        Raises:
            RuntimeError: if no documents have been indexed yet.
        """
        # Fail fast, before starting a graph run.
        self._require_vector_store()

        config = {"configurable": {"thread_id": thread_id}}
        result = self.workflow.invoke({"question": question}, config=config)
        return result["answer"]