mabelwang21 committed on
Commit
f9bd4a9
·
1 Parent(s): e656aa6

test RAG in agent

Browse files
Files changed (1) hide show
  1. agent.py +102 -40
agent.py CHANGED
@@ -9,8 +9,18 @@ from langchain.tools import tool
9
  from langchain_community.document_loaders import (
10
  CSVLoader,
11
  YoutubeLoader,
 
 
 
 
 
 
 
 
 
12
  )
13
 
 
14
  from langchain.chat_models import init_chat_model
15
  from langchain.agents import initialize_agent, AgentType
16
  from langchain_community.retrievers import BM25Retriever
@@ -165,7 +175,7 @@ class AgentState(TypedDict):
165
  # The document provided
166
  input_file: Optional[str] # Contains file path (PDF/PNG)
167
  messages: Annotated[list[AnyMessage], add_messages]
168
-
169
  # === Agent Class ===
170
  class MyAgent:
171
  def __init__(
@@ -175,51 +185,103 @@ class MyAgent:
175
  ):
176
  # Initialize LLM
177
  self.llm = init_chat_model(model_name, temperature=temperature)
 
 
 
 
 
178
 
179
- # Base tools: use provided tools or default list
180
- self.tools = tools
181
-
182
- # Human-readable tool descriptions
183
- self.textual_tool_desc = "\n".join(t.__doc__.strip() for t in self.tools)
184
-
185
- # Define assistant node
186
- def assistant_node(state: AgentState) -> dict:
187
- sys_msg = SystemMessage(
188
- content="\n".join([
189
- SYSTEM_PROMPT,
190
- "\nTools available:\n" + self.textual_tool_desc
191
- ])
192
- )
193
- msgs = [sys_msg] + state["messages"]
194
- response = self.llm(msgs)
195
- return {"messages": state["messages"] + [response], "input_file": state.get("input_file")}
196
-
197
- # Condition to invoke tools: check if last LLM message mentions a tool invocation
198
- def needs_tool(state: AgentState) -> bool:
199
- last = state["messages"][-1].content.lower()
200
- return any(f"{t.__name__.lower()}(" in last for t in self.tools)
201
-
202
- # Build the state graph
203
- builder = StateGraph(AgentState)
204
- builder.add_node("assistant", assistant_node)
205
- builder.add_node("tools", ToolNode(self.tools))
206
- builder.add_edge(START, "assistant")
207
- builder.add_conditional_edges("assistant", needs_tool)
208
- builder.add_edge("tools", "assistant")
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
- self.react_graph = builder.compile()
211
 
212
  def __call__(
213
  self,
214
- user_input: str,
215
- input_file: Optional[str] = None,
216
  ) -> str:
217
- state = AgentState()
218
- state["messages"] = [HumanMessage(content=user_input)]
219
- state["input_file"] = input_file
220
- out = self.react_graph(state)
221
- # Return only the final LLM message content
222
- return out["messages"][-1].content.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
 
224
  # CLI entrypoint
225
  if __name__ == "__main__":
 
9
  from langchain_community.document_loaders import (
10
  CSVLoader,
11
  YoutubeLoader,
12
+ PyPDFLoader
13
+ )
14
+ from langchain_community.document_loaders.blob_loaders.youtube_audio import (
15
+ YoutubeAudioLoader,
16
+ )
17
+ from langchain_community.document_loaders.generic import GenericLoader
18
+ from langchain_community.document_loaders.parsers.audio import (
19
+ OpenAIWhisperParser,
20
+ OpenAIWhisperParserLocal,
21
  )
22
 
23
+ #from langchain_community.document_loaders import AudioLoader, WhisperLoader
24
  from langchain.chat_models import init_chat_model
25
  from langchain.agents import initialize_agent, AgentType
26
  from langchain_community.retrievers import BM25Retriever
 
175
  # The document provided
176
  input_file: Optional[str] # Contains file path (PDF/PNG)
177
  messages: Annotated[list[AnyMessage], add_messages]
178
+
179
  # === Agent Class ===
180
  class MyAgent:
181
  def __init__(
 
185
  ):
186
  # Initialize LLM
187
  self.llm = init_chat_model(model_name, temperature=temperature)
188
+ # Base tools
189
+ self.tools = tools
190
+ # RAG components
191
+ self.docs: List[Any] = []
192
+ self.retriever: Optional[BM25Retriever] = None
193
 
194
def add_files(self, file_paths: List[str]):
    """
    Load documents for RAG from local files and YouTube URLs.

    Each entry is dispatched on its lowercase file extension (``.csv``,
    ``.pdf``, ``.mp3``, ``.wav``) or, failing that, on whether it looks
    like a YouTube link. Loaded documents are appended to ``self.docs``.
    Entries of any other kind are skipped without raising.

    Args:
        file_paths: Local file paths and/or YouTube URLs to ingest.
    """
    for path in file_paths:
        ext = Path(path).suffix.lower()
        if ext == ".csv":
            self.docs.extend(CSVLoader(path).load())
        elif ext == ".pdf":
            self.docs.extend(PyPDFLoader(path).load())
        elif ext in (".mp3", ".wav"):
            # The AudioLoader / WhisperLoader names used previously are
            # never imported (their import is commented out), so this
            # branch raised NameError. Transcribe with the Whisper parser
            # the module already imports instead.
            from langchain_community.document_loaders.blob_loaders import Blob

            parser = OpenAIWhisperParser()
            self.docs.extend(parser.parse(Blob.from_path(path)))
        elif "youtube" in path or "youtu.be" in path:
            # Short "youtu.be" share links are accepted as well.
            loader = YoutubeLoader.from_youtube_url(path)
            self.docs.extend(loader.load())
        # Anything else is unsupported and deliberately skipped.
215
+
216
def build_retriever(self):
    """
    Build a BM25 retriever over ``self.docs`` and register ``rag_search``.

    Does nothing when no documents have been loaded. Otherwise
    ``self.retriever`` is rebuilt from the current documents and a
    ``rag_search`` tool is made available in ``self.tools``. Calling this
    method repeatedly keeps exactly one ``rag_search`` tool registered.
    """
    if not self.docs:
        return
    self.retriever = BM25Retriever.from_documents(self.docs)

    @tool
    def rag_search(query: str) -> str:
        """
        Retrieve top-3 relevant document chunks via BM25.
        """
        hits = self.retriever.invoke(query)
        # Joining an empty result yields "", matching the no-hit case.
        return "\n\n".join(doc.page_content for doc in hits[:3])

    # Drop any rag_search registered by an earlier call: duplicate tool
    # names are ambiguous for the tool executor. Rebinding (rather than
    # appending) also leaves the caller-supplied list untouched and copes
    # with ``self.tools`` being None.
    other_tools = [
        t for t in (self.tools or [])
        if getattr(t, "name", None) != "rag_search"
    ]
    self.tools = other_tools + [rag_search]
236
 
 
237
 
238
  def __call__(
239
  self,
240
+ question: str,
241
+ file_paths: Optional[List[str]] = None
242
  ) -> str:
243
+ # Prepare state graph
244
+ state: Dict[str, Any] = {"messages": [], "input_file": None}
245
+
246
+ # Add system message
247
+ tool_desc = "\n".join(f"{tool_func.__name__}: {tool_func.__doc__.strip()}" \
248
+ for tool_func in self.tools)
249
+ sys_msg = SystemMessage(content=f"{SYSTEM_PROMPT}\n\nTools:\n{tool_desc}")
250
+ state["messages"].append(sys_msg)
251
+
252
+ # Optionally load RAG docs
253
+ if file_paths:
254
+ self.add_files(file_paths)
255
+ self.build_retriever()
256
+
257
+ # Add user question
258
+ state["messages"].append(HumanMessage(content=question))
259
+ if file_paths:
260
+ state["input_file"] = file_paths
261
+
262
+ # Build graph
263
+ builder = StateGraph(dict)
264
+ builder.add_node("assistant", self._assistant_node)
265
+ builder.add_node("tools", ToolNode(self.tools))
266
+ builder.add_edge(START, "assistant")
267
+ builder.add_conditional_edges(
268
+ "assistant",
269
+ lambda s: any(t.__name__ in s["messages"][-1].content for t in self.tools),
270
+ "tools"
271
+ )
272
+ builder.add_edge("tools", "assistant")
273
+ graph = builder.compile()
274
+
275
+ # Run graph until completion
276
+ out = graph.run(state)
277
+ return out["messages"][-1].content
278
+
279
+ def _assistant_node(self, state: dict) -> dict:
280
+ # Invoke LLM on current messages
281
+ resp = self.llm.invoke(state["messages"])
282
+ state["messages"].append(resp)
283
+ return state
284
+
285
 
286
  # CLI entrypoint
287
  if __name__ == "__main__":