Dinesh310 committed on
Commit
f68c145
·
verified ·
1 Parent(s): b5135ee

Update src/node/reactnode.py

Browse files
Files changed (1) hide show
  1. src/node/reactnode.py +160 -92
src/node/reactnode.py CHANGED
@@ -1,92 +1,160 @@
1
- """LangGraph nodes for RAG workflow + ReAct Agent inside generate_content"""
2
-
3
- from typing import List, Optional
4
- from src.state.rag_state import RAGState
5
-
6
- from langchain_core.documents import Document
7
- from langchain_core.tools import Tool
8
- from langchain_core.messages import HumanMessage
9
- from langgraph.prebuilt import create_react_agent
10
-
11
- # Wikipedia tool
12
- from langchain_community.utilities import WikipediaAPIWrapper
13
- from langchain_community.tools.wikipedia.tool import WikipediaQueryRun
14
-
15
-
16
- class RAGNodes:
17
- """Contains node functions for RAG workflow"""
18
-
19
- def __init__(self, retriever, llm):
20
- self.retriever = retriever
21
- self.llm = llm
22
- self._agent = None # lazy-init agent
23
-
24
- def retrieve_docs(self, state: RAGState) -> RAGState:
25
- """Classic retriever node"""
26
- docs = self.retriever.invoke(state.question)
27
- return RAGState(
28
- question=state.question,
29
- retrieved_docs=docs
30
- )
31
-
32
- def _build_tools(self) -> List[Tool]:
33
- """Build retriever + wikipedia tools"""
34
-
35
- def retriever_tool_fn(query: str) -> str:
36
- docs: List[Document] = self.retriever.invoke(query)
37
- if not docs:
38
- return "No documents found."
39
- merged = []
40
- for i, d in enumerate(docs[:8], start=1):
41
- meta = d.metadata if hasattr(d, "metadata") else {}
42
- title = meta.get("title") or meta.get("source") or f"doc_{i}"
43
- merged.append(f"[{i}] {title}\n{d.page_content}")
44
- return "\n\n".join(merged)
45
-
46
- retriever_tool = Tool(
47
- name="retriever",
48
- description="Fetch passages from indexed corpus.",
49
- func=retriever_tool_fn,
50
- )
51
-
52
- wiki = WikipediaQueryRun(
53
- api_wrapper=WikipediaAPIWrapper(top_k_results=3, lang="en")
54
- )
55
- wikipedia_tool = Tool(
56
- name="wikipedia",
57
- description="Search Wikipedia for general knowledge.",
58
- func=wiki.run,
59
- )
60
-
61
- return [retriever_tool, wikipedia_tool]
62
-
63
- def _build_agent(self):
64
- """ReAct agent with tools"""
65
- tools = self._build_tools()
66
- system_prompt = (
67
- "You are a helpful RAG agent. "
68
- "Prefer 'retriever' for user-provided docs; use 'wikipedia' for general knowledge. "
69
- "Return only the final useful answer."
70
- )
71
- self._agent = create_react_agent(self.llm, tools=tools,prompt=system_prompt)
72
-
73
- def generate_answer(self, state: RAGState) -> RAGState:
74
- """
75
- Generate answer using ReAct agent with retriever + wikipedia.
76
- """
77
- if self._agent is None:
78
- self._build_agent()
79
-
80
- result = self._agent.invoke({"messages": [HumanMessage(content=state.question)]})
81
-
82
- messages = result.get("messages", [])
83
- answer: Optional[str] = None
84
- if messages:
85
- answer_msg = messages[-1]
86
- answer = getattr(answer_msg, "content", None)
87
-
88
- return RAGState(
89
- question=state.question,
90
- retrieved_docs=state.retrieved_docs,
91
- answer=answer or "Could not generate answer."
92
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LangGraph nodes for the RAG workflow: document retrieval and context-grounded answer generation."""
2
+
3
+ from typing import List
4
+ from langchain_core.documents import Document
5
+ from langchain_core.prompts import ChatPromptTemplate
6
+
7
class RAGNodes:
    """Node callables for a LangGraph-based RAG workflow.

    Each node receives the graph state as a dict and returns a dict holding
    only the keys it produces: ``retrieve`` yields ``"context"`` and
    ``generate`` yields ``"answer"``.
    """

    def __init__(self, vector_store, llm):
        """Keep references to the collaborators the nodes rely on.

        Args:
            vector_store: Object exposing ``as_retriever(...)``. The
                ``retrieve`` docstring mentions FAISS -- TODO confirm against
                the caller; nothing here is FAISS-specific.
            llm: Runnable that can be piped after a ``ChatPromptTemplate``
                and whose result exposes a ``.content`` attribute.
        """
        self.vector_store = vector_store
        self.llm = llm

    # -------------------------
    # RETRIEVE NODE
    # -------------------------
    def retrieve(self, state: dict) -> dict:
        """Node: fetch documents for ``state["question"]`` from the vector store.

        Args:
            state: Graph state; must contain the key ``"question"``.

        Returns:
            ``{"context": documents}`` with whatever the retriever returned.

        Raises:
            KeyError: If ``state`` has no ``"question"`` entry.
        """
        print("--- RETRIEVING ---")

        # MMR search returning 5 documents. NOTE(review): per LangChain's
        # docs a low lambda_mult favours diversity over similarity -- confirm
        # 0.25 is the intended trade-off.
        retriever = self.vector_store.as_retriever(
            search_type="mmr",
            search_kwargs={"k": 5, "lambda_mult": 0.25},
        )

        documents: List[Document] = retriever.invoke(state["question"])
        return {"context": documents}

    # -------------------------
    # GENERATE NODE
    # -------------------------
    def generate(self, state: dict) -> dict:
        """Node: generate an answer with the LLM strictly from the context.

        Args:
            state: Graph state; must contain ``"question"``. ``"context"`` is
                the document list produced by ``retrieve``; a missing or
                empty value is treated as "no context".

        Returns:
            ``{"answer": text}`` where ``text`` is the ``.content`` of the
            LLM response.

        Raises:
            KeyError: If ``state`` has no ``"question"`` entry.
        """
        print("--- GENERATING ---")

        prompt = ChatPromptTemplate.from_template("""
        You are a professional Project Analyst.

        Use ONLY the following context to answer the question.
        If the answer is not in the context, say "I don't know".

        Context:
        {context}

        Question:
        {question}

        Answer (cite sources if possible):
        """)

        # A missing/empty context must not crash the node: the prompt already
        # instructs the model to answer "I don't know" in that situation.
        formatted_context = self._format_context(state.get("context") or [])

        chain = prompt | self.llm

        response = chain.invoke({
            "context": formatted_context,
            "question": state["question"],
        })

        return {"answer": response.content}

    @staticmethod
    def _format_context(documents: "List[Document]") -> str:
        """Render documents as numbered passages for the prompt.

        The prompt asks the model to cite sources, so each passage is
        prefixed with its ``title`` or ``source`` metadata when one is
        present. Documents without such metadata are rendered as
        ``"[n] <page_content>"``.

        Args:
            documents: Iterable of objects exposing ``page_content`` and,
                optionally, a ``metadata`` mapping.

        Returns:
            The passages joined by blank lines; ``""`` for no documents.
        """
        passages = []
        for position, doc in enumerate(documents, start=1):
            metadata = getattr(doc, "metadata", None) or {}
            label = metadata.get("title") or metadata.get("source")
            if label:
                passages.append(f"[{position}] {label}\n{doc.page_content}")
            else:
                passages.append(f"[{position}] {doc.page_content}")
        return "\n\n".join(passages)
65
+
66
+
67
+
68
+
69
+
70
+
71
+ # from typing import List, Optional
72
+ # from src.state.rag_state import RAGState
73
+
74
+ # from langchain_core.documents import Document
75
+ # from langchain_core.tools import Tool
76
+ # from langchain_core.messages import HumanMessage
77
+ # from langgraph.prebuilt import create_react_agent
78
+
79
+ # # Wikipedia tool
80
+ # from langchain_community.utilities import WikipediaAPIWrapper
81
+ # from langchain_community.tools.wikipedia.tool import WikipediaQueryRun
82
+
83
+
84
+ # class RAGNodes:
85
+ # """Contains node functions for RAG workflow"""
86
+
87
+ # def __init__(self, retriever, llm):
88
+ # self.retriever = retriever
89
+ # self.llm = llm
90
+ # self._agent = None # lazy-init agent
91
+
92
+ # def retrieve_docs(self, state: RAGState) -> RAGState:
93
+ # """Classic retriever node"""
94
+ # docs = self.retriever.invoke(state.question)
95
+ # return RAGState(
96
+ # question=state.question,
97
+ # retrieved_docs=docs
98
+ # )
99
+
100
+ # def _build_tools(self) -> List[Tool]:
101
+ # """Build retriever + wikipedia tools"""
102
+
103
+ # def retriever_tool_fn(query: str) -> str:
104
+ # docs: List[Document] = self.retriever.invoke(query)
105
+ # if not docs:
106
+ # return "No documents found."
107
+ # merged = []
108
+ # for i, d in enumerate(docs[:8], start=1):
109
+ # meta = d.metadata if hasattr(d, "metadata") else {}
110
+ # title = meta.get("title") or meta.get("source") or f"doc_{i}"
111
+ # merged.append(f"[{i}] {title}\n{d.page_content}")
112
+ # return "\n\n".join(merged)
113
+
114
+ # retriever_tool = Tool(
115
+ # name="retriever",
116
+ # description="Fetch passages from indexed corpus.",
117
+ # func=retriever_tool_fn,
118
+ # )
119
+
120
+ # wiki = WikipediaQueryRun(
121
+ # api_wrapper=WikipediaAPIWrapper(top_k_results=3, lang="en")
122
+ # )
123
+ # wikipedia_tool = Tool(
124
+ # name="wikipedia",
125
+ # description="Search Wikipedia for general knowledge.",
126
+ # func=wiki.run,
127
+ # )
128
+
129
+ # return [retriever_tool, wikipedia_tool]
130
+
131
+ # def _build_agent(self):
132
+ # """ReAct agent with tools"""
133
+ # tools = self._build_tools()
134
+ # system_prompt = (
135
+ # "You are a helpful RAG agent. "
136
+ # "Prefer 'retriever' for user-provided docs; use 'wikipedia' for general knowledge. "
137
+ # "Return only the final useful answer."
138
+ # )
139
+ # self._agent = create_react_agent(self.llm, tools=tools,prompt=system_prompt)
140
+
141
+ # def generate_answer(self, state: RAGState) -> RAGState:
142
+ # """
143
+ # Generate answer using ReAct agent with retriever + wikipedia.
144
+ # """
145
+ # if self._agent is None:
146
+ # self._build_agent()
147
+
148
+ # result = self._agent.invoke({"messages": [HumanMessage(content=state.question)]})
149
+
150
+ # messages = result.get("messages", [])
151
+ # answer: Optional[str] = None
152
+ # if messages:
153
+ # answer_msg = messages[-1]
154
+ # answer = getattr(answer_msg, "content", None)
155
+
156
+ # return RAGState(
157
+ # question=state.question,
158
+ # retrieved_docs=state.retrieved_docs,
159
+ # answer=answer or "Could not generate answer."
160
+ # )