File size: 5,094 Bytes
f68c145 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
"""LangGraph nodes for the RAG workflow: MMR document retrieval and context-only answer generation (the ReAct-agent variant further down is commented out)."""
from typing import List
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
class RAGNodes:
    """Graph nodes for a LangGraph-based RAG workflow.

    Each public method is a node: it receives the graph state as a dict and
    returns a dict containing only the key that node produces.

    Args:
        vector_store: Object exposing ``as_retriever(search_type=...,
            search_kwargs=...)`` whose result supports ``invoke(query)``.
        llm: Runnable that can be piped after a ``ChatPromptTemplate``.
        k: Number of documents requested from the retriever per query.
        lambda_mult: MMR ``lambda_mult`` value passed through to the
            retriever in ``search_kwargs``.
    """

    # Prompt text lives in one place instead of being an inline literal in
    # ``generate``. Written as explicit "\n"-terminated pieces so the value
    # does not depend on how this file happens to be indented.
    PROMPT_TEMPLATE = (
        "\n"
        "You are a professional Project Analyst.\n"
        "Use ONLY the following context to answer the question.\n"
        "If the answer is not in the context, say \"I don't know\".\n"
        "Context:\n"
        "{context}\n"
        "Question:\n"
        "{question}\n"
        "Answer (cite sources if possible):\n"
    )

    def __init__(self, vector_store, llm, *, k: int = 5, lambda_mult: float = 0.25):
        self.vector_store = vector_store
        self.llm = llm
        # Previously hard-coded in ``retrieve``; the defaults keep the old behavior.
        self.k = k
        self.lambda_mult = lambda_mult

    # -------------------------
    # RETRIEVE NODE
    # -------------------------
    def retrieve(self, state: dict) -> dict:
        """Node: fetch documents for ``state["question"]`` from the vector store.

        Builds an MMR retriever from ``self.vector_store`` and invokes it with
        the question.

        Args:
            state: Graph state; must contain the key ``"question"``.

        Returns:
            ``{"context": documents}`` where ``documents`` is whatever the
            retriever returned.

        Raises:
            KeyError: If ``state`` has no ``"question"`` key.
        """
        print("--- RETRIEVING ---")
        retriever = self.vector_store.as_retriever(
            search_type="mmr",
            search_kwargs={"k": self.k, "lambda_mult": self.lambda_mult},
        )
        documents: List[Document] = retriever.invoke(state["question"])
        return {"context": documents}

    # -------------------------
    # GENERATE NODE
    # -------------------------
    @staticmethod
    def _format_context(documents) -> str:
        """Join documents into numbered ``[n] text`` passages separated by blank lines.

        ``None`` or an empty sequence yields an empty string.
        """
        return "\n\n".join(
            f"[{position}] {doc.page_content}"
            for position, doc in enumerate(documents or [], start=1)
        )

    def generate(self, state: dict) -> dict:
        """Node: generate an answer with the LLM strictly from the retrieved context.

        Args:
            state: Graph state; must contain ``"question"``. ``"context"`` is
                the list of retrieved documents; a missing, ``None`` or empty
                value is treated as "no context" instead of raising.

        Returns:
            ``{"answer": ...}`` holding the response's ``content`` attribute,
            or the response itself when it has no such attribute.

        Raises:
            KeyError: If ``state`` has no ``"question"`` key.
        """
        print("--- GENERATING ---")
        prompt = ChatPromptTemplate.from_template(self.PROMPT_TEMPLATE)

        # Format retrieved documents
        formatted_context = self._format_context(state.get("context"))

        chain = prompt | self.llm
        response = chain.invoke(
            {
                "context": formatted_context,
                "question": state["question"],
            }
        )
        # Message-style responses carry the text in ``.content``; fall back to
        # the raw response (e.g. a plain string) when that attribute is absent.
        answer = getattr(response, "content", response)
        return {"answer": answer}
# from typing import List, Optional
# from src.state.rag_state import RAGState
# from langchain_core.documents import Document
# from langchain_core.tools import Tool
# from langchain_core.messages import HumanMessage
# from langgraph.prebuilt import create_react_agent
# # Wikipedia tool
# from langchain_community.utilities import WikipediaAPIWrapper
# from langchain_community.tools.wikipedia.tool import WikipediaQueryRun
# class RAGNodes:
# """Contains node functions for RAG workflow"""
# def __init__(self, retriever, llm):
# self.retriever = retriever
# self.llm = llm
# self._agent = None # lazy-init agent
# def retrieve_docs(self, state: RAGState) -> RAGState:
# """Classic retriever node"""
# docs = self.retriever.invoke(state.question)
# return RAGState(
# question=state.question,
# retrieved_docs=docs
# )
# def _build_tools(self) -> List[Tool]:
# """Build retriever + wikipedia tools"""
# def retriever_tool_fn(query: str) -> str:
# docs: List[Document] = self.retriever.invoke(query)
# if not docs:
# return "No documents found."
# merged = []
# for i, d in enumerate(docs[:8], start=1):
# meta = d.metadata if hasattr(d, "metadata") else {}
# title = meta.get("title") or meta.get("source") or f"doc_{i}"
# merged.append(f"[{i}] {title}\n{d.page_content}")
# return "\n\n".join(merged)
# retriever_tool = Tool(
# name="retriever",
# description="Fetch passages from indexed corpus.",
# func=retriever_tool_fn,
# )
# wiki = WikipediaQueryRun(
# api_wrapper=WikipediaAPIWrapper(top_k_results=3, lang="en")
# )
# wikipedia_tool = Tool(
# name="wikipedia",
# description="Search Wikipedia for general knowledge.",
# func=wiki.run,
# )
# return [retriever_tool, wikipedia_tool]
# def _build_agent(self):
# """ReAct agent with tools"""
# tools = self._build_tools()
# system_prompt = (
# "You are a helpful RAG agent. "
# "Prefer 'retriever' for user-provided docs; use 'wikipedia' for general knowledge. "
# "Return only the final useful answer."
# )
# self._agent = create_react_agent(self.llm, tools=tools,prompt=system_prompt)
# def generate_answer(self, state: RAGState) -> RAGState:
# """
# Generate answer using ReAct agent with retriever + wikipedia.
# """
# if self._agent is None:
# self._build_agent()
# result = self._agent.invoke({"messages": [HumanMessage(content=state.question)]})
# messages = result.get("messages", [])
# answer: Optional[str] = None
# if messages:
# answer_msg = messages[-1]
# answer = getattr(answer_msg, "content", None)
# return RAGState(
# question=state.question,
# retrieved_docs=state.retrieved_docs,
# answer=answer or "Could not generate answer."
# )
|