| | """LangGraph nodes for the RAG workflow, with a ReAct agent used inside generate_answer."""
|
| |
|
| | from typing import List, Optional
|
| | from src.state.rag_state import RAGState
|
| |
|
| | from langchain_core.documents import Document
|
| | from langchain_core.tools import Tool
|
| | from langchain_core.messages import HumanMessage
|
| | from langgraph.prebuilt import create_react_agent
|
| |
|
| |
|
| | from langchain_community.utilities import WikipediaAPIWrapper
|
| | from langchain_community.tools.wikipedia.tool import WikipediaQueryRun
|
| |
|
| |
|
class RAGNodes:
    """Node functions for the RAG workflow.

    Holds a retriever and an LLM. A ReAct agent equipped with a retriever
    tool and a Wikipedia tool is built lazily the first time
    ``generate_answer`` runs and is then reused.
    """

    # Maximum number of retrieved documents the retriever tool hands back
    # to the agent in a single call.
    MAX_TOOL_DOCS = 8

    def __init__(self, retriever, llm):
        """Store the collaborators; the agent is created on first use.

        Args:
            retriever: Object exposing ``invoke(query)`` that returns documents.
            llm: Model passed to ``create_react_agent``.
        """
        self.retriever = retriever
        self.llm = llm
        self._agent = None

    def retrieve_docs(self, state: RAGState) -> RAGState:
        """Classic retriever node.

        Invokes the retriever with ``state.question`` and returns a new
        ``RAGState`` holding the question and the retrieved documents.
        """
        docs = self.retriever.invoke(state.question)
        return RAGState(
            question=state.question,
            retrieved_docs=docs,
        )

    def _build_tools(self) -> List[Tool]:
        """Build the retriever and Wikipedia tools offered to the agent."""

        def retriever_tool_fn(query: str) -> str:
            """Return the top retrieved passages as numbered, titled text."""
            docs: List[Document] = self.retriever.invoke(query)
            if not docs:
                return "No documents found."
            merged = []
            for i, d in enumerate(docs[: self.MAX_TOOL_DOCS], start=1):
                # Metadata may be missing or None; fall back to an empty
                # mapping so the title lookup below cannot raise.
                meta = getattr(d, "metadata", None) or {}
                title = meta.get("title") or meta.get("source") or f"doc_{i}"
                merged.append(f"[{i}] {title}\n{d.page_content}")
            return "\n\n".join(merged)

        retriever_tool = Tool(
            name="retriever",
            description="Fetch passages from indexed corpus.",
            func=retriever_tool_fn,
        )

        wiki = WikipediaQueryRun(
            api_wrapper=WikipediaAPIWrapper(top_k_results=3, lang="en")
        )
        wikipedia_tool = Tool(
            name="wikipedia",
            description="Search Wikipedia for general knowledge.",
            func=wiki.run,
        )

        return [retriever_tool, wikipedia_tool]

    def _build_agent(self):
        """Create the ReAct agent with its tools and store it on ``self._agent``."""
        tools = self._build_tools()
        system_prompt = (
            "You are a helpful RAG agent. "
            "Prefer 'retriever' for user-provided docs; use 'wikipedia' for general knowledge. "
            "Return only the final useful answer."
        )
        self._agent = create_react_agent(self.llm, tools=tools, prompt=system_prompt)

    @staticmethod
    def _content_to_text(content) -> Optional[str]:
        """Flatten a message's ``content`` into plain text.

        Message content is not always a string: some chat models return a
        list of content blocks (plain strings, or dicts carrying a ``"text"``
        entry). Text pieces are concatenated in order; blocks without text
        (for example tool-call blocks) are skipped.

        Args:
            content: A string, a list of content blocks, ``None``, or any
                other object.

        Returns:
            The text, ``None`` when ``content`` is ``None``, or
            ``str(content)`` for any other unexpected type.
        """
        if content is None:
            return None
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and isinstance(block.get("text"), str):
                    parts.append(block["text"])
            return "".join(parts)
        return str(content)

    def generate_answer(self, state: RAGState) -> RAGState:
        """Generate an answer using the ReAct agent (retriever + wikipedia).

        The agent is built on the first call. The content of the last message
        in the agent's result is normalised to text and used as the answer;
        when there is no message or no text, a fixed fallback string is used.

        Returns:
            A new ``RAGState`` with the original question and retrieved
            documents plus the answer.
        """
        if self._agent is None:
            self._build_agent()

        result = self._agent.invoke(
            {"messages": [HumanMessage(content=state.question)]}
        )

        messages = result.get("messages", [])
        answer: Optional[str] = None
        if messages:
            # The final message holds the agent's answer; its content may be
            # a list of blocks rather than a string, so flatten it.
            answer = self._content_to_text(getattr(messages[-1], "content", None))

        return RAGState(
            question=state.question,
            retrieved_docs=state.retrieved_docs,
            answer=answer or "Could not generate answer.",
        )
|
| |
|