|
|
"""LangGraph nodes for RAG workflow""" |
|
|
|
|
|
from src.state.rag_state import RAGState |
|
|
|
|
|
class RAGNodes:
    """Contains node functions for RAG workflow.

    Each public node method takes the current ``RAGState`` and returns a
    newly constructed ``RAGState``.
    """

    def __init__(self, retriever, llm):
        """
        Initialize RAG nodes

        Args:
            retriever: Document retriever instance. Must expose
                ``invoke(question)`` returning an iterable of documents that
                each have a ``page_content`` attribute.
            llm: Language model instance. Must expose ``invoke(prompt)``.
                Either a chat model (returning a message whose text is in
                ``.content``) or a plain-text LLM (returning ``str``).
        """
        self.retriever = retriever
        self.llm = llm

    def retrieve_docs(self, state: RAGState) -> RAGState:
        """
        Retrieve relevant documents node

        Args:
            state: Current RAG state; only ``state.question`` is read.

        Returns:
            A new RAG state holding the original question and the documents
            returned by ``self.retriever.invoke(state.question)``. Only
            ``question`` and ``retrieved_docs`` are passed to the
            constructor; any other ``RAGState`` fields take their defaults.
            NOTE(review): confirm that resetting those fields is intended.
        """
        docs = self.retriever.invoke(state.question)
        return RAGState(question=state.question, retrieved_docs=docs)

    def generate_answer(self, state: RAGState) -> RAGState:
        """
        Generate answer from retrieved documents node

        Builds a prompt from the question and the ``page_content`` of every
        retrieved document (separated by blank lines), sends it to
        ``self.llm``, and stores the returned text as the answer.

        Args:
            state: Current RAG state with retrieved documents

        Returns:
            A new RAG state with the same question and retrieved documents
            plus the generated ``answer`` text.
        """
        # Treat a missing document list as "no context" so the LLM is still
        # called with an empty context instead of raising TypeError.
        docs = state.retrieved_docs or []
        context = "\n\n".join(doc.page_content for doc in docs)

        prompt = (
            "You are a professional Project Analyst.\n"
            "Answer strictly using the context.\n"
            "If unknown, say you don't know.\n"
            "\n"
            "Context:\n"
            f"{context}\n"
            "\n"
            f"Question: {state.question}"
        )

        response = self.llm.invoke(prompt)
        # Chat models return a message object with the text in ``.content``;
        # plain-text LLMs return the string directly. Accept both.
        answer = response if isinstance(response, str) else response.content

        return RAGState(
            question=state.question,
            retrieved_docs=state.retrieved_docs,
            answer=answer,
        )