agercas committed on
Commit
2a3616b
·
1 Parent(s): b6088cd
Files changed (2) hide show
  1. pyproject.toml +1 -0
  2. src/agents/langgraph_agent.py +73 -85
pyproject.toml CHANGED
@@ -33,6 +33,7 @@ ruff = "^0.11.12"
33
  [tool.ruff]
34
  line-length = 120
35
  target-version = "py312"
 
36
 
37
  [tool.ruff.lint]
38
  select = ["I", "E4", "E7", "E9", "F", "B", "UP"]
 
33
  [tool.ruff]
34
  line-length = 120
35
  target-version = "py312"
36
+ exclude = ["notebooks/", "*.ipynb"]
37
 
38
  [tool.ruff.lint]
39
  select = ["I", "E4", "E7", "E9", "F", "B", "UP"]
src/agents/langgraph_agent.py CHANGED
@@ -1,22 +1,22 @@
1
- from typing import Annotated, Sequence, TypedDict, Literal
2
- from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, ToolMessage
3
- from langchain_core.runnables import RunnableConfig
4
- from langgraph.graph.message import add_messages
5
- from langgraph.graph import StateGraph, END
6
- from pydantic import BaseModel, Field
7
  from langchain.chat_models import init_chat_model
8
- import json
9
 
10
  # Import tools
11
- from langchain_community.tools import DuckDuckGoSearchRun
 
12
  from langchain_community.tools.pubmed.tool import PubmedQueryRun
13
  from langchain_community.tools.semanticscholar.tool import SemanticScholarQueryRun
14
- from langchain_community.tools.arxiv import ArxivQueryRun
15
  from langchain_community.tools.wikidata.tool import WikidataQueryRun
16
- from langchain_community.tools import WikipediaQueryRun
17
  from langchain_community.utilities import WikipediaAPIWrapper
18
- from langchain_experimental.utilities import PythonREPL
 
19
  from langchain_core.tools import Tool
 
 
 
 
20
 
21
  # Set up tools
22
  python_repl = PythonREPL()
@@ -34,7 +34,7 @@ tools = [
34
  ArxivQueryRun(),
35
  WikidataQueryRun(),
36
  WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()),
37
- repl_tool
38
  ]
39
 
40
  # Initialize Gemini model
@@ -44,39 +44,46 @@ model_with_tools = model.bind_tools(tools)
44
  # Create tools lookup
45
  tools_by_name = {tool.name: tool for tool in tools}
46
 
 
47
  # Pydantic models for structured output
48
  class ToolSufficiencyResponse(BaseModel):
49
  """Response for tool sufficiency check"""
 
50
  sufficient: bool = Field(description="Whether the available tools are sufficient to answer the question")
51
  reasoning: str = Field(description="Brief reasoning for the decision")
52
 
 
53
  class FinalAnswer(BaseModel):
54
  """Final answer structure"""
 
55
  answer: str = Field(description="The comprehensive answer to the user's question")
56
  confidence: Literal["high", "medium", "low"] = Field(description="Confidence level in the answer")
57
  sources_used: list[str] = Field(description="List of tools/sources that were used to generate the answer")
58
 
 
59
  # Define graph state
60
  class AgentState(TypedDict):
61
  """The state of the agent."""
 
62
  messages: Annotated[Sequence[BaseMessage], add_messages]
63
  llm_call_count: int
64
  max_llm_calls: int
65
 
 
66
  # Node functions
67
  def check_tool_sufficiency(state: AgentState, config: RunnableConfig):
68
  """Check if available tools are sufficient to answer the question"""
69
-
70
  # Get the user's question
71
  user_message = None
72
  for msg in state["messages"]:
73
  if msg.type == "human":
74
  user_message = msg.content
75
  break
76
-
77
  # Create system prompt for sufficiency check
78
  available_tools_desc = "\n".join([f"- {tool.name}: {tool.description}" for tool in tools])
79
-
80
  system_prompt = f"""You are an AI assistant that needs to determine if the available tools are sufficient to answer a user's question.
81
 
82
  Available tools:
@@ -93,27 +100,22 @@ Be generous in your assessment - if there's a reasonable path to answer the ques
93
 
94
  # Use structured output for sufficiency check
95
  structured_model = model.with_structured_output(ToolSufficiencyResponse)
96
-
97
- messages = [
98
- SystemMessage(content=system_prompt),
99
- HumanMessage(content=f"Question to analyze: {user_message}")
100
- ]
101
-
102
  response = structured_model.invoke(messages, config)
103
-
104
  # Add response to messages for context
105
  response_message = SystemMessage(
106
  content=f"Tool sufficiency check: {'Sufficient' if response.sufficient else 'Insufficient'}. Reasoning: {response.reasoning}"
107
  )
108
-
109
- return {
110
- "messages": [response_message],
111
- "tool_sufficiency": response.sufficient
112
- }
113
 
114
  def call_model(state: AgentState, config: RunnableConfig):
115
  """Call the model (ReAct agent LLM node)"""
116
-
117
  system_prompt = SystemMessage(
118
  content="""You are a helpful AI assistant with access to various tools. Use the tools available to you to answer the user's question comprehensively.
119
 
@@ -124,22 +126,20 @@ Think step by step:
124
 
125
  Be thorough but efficient with your tool usage."""
126
  )
127
-
128
  response = model_with_tools.invoke([system_prompt] + state["messages"], config)
129
-
130
  # Increment LLM call count
131
  new_count = state.get("llm_call_count", 0) + 1
132
-
133
- return {
134
- "messages": [response],
135
- "llm_call_count": new_count
136
- }
137
 
138
  def tool_node(state: AgentState):
139
  """Execute tools based on the last message's tool calls"""
140
  outputs = []
141
  last_message = state["messages"][-1]
142
-
143
  for tool_call in last_message.tool_calls:
144
  try:
145
  tool_result = tools_by_name[tool_call["name"]].invoke(tool_call["args"])
@@ -158,12 +158,13 @@ def tool_node(state: AgentState):
158
  tool_call_id=tool_call["id"],
159
  )
160
  )
161
-
162
  return {"messages": outputs}
163
 
 
164
  def final_answer_node(state: AgentState, config: RunnableConfig):
165
  """Generate final structured answer based on conversation history"""
166
-
167
  system_prompt = SystemMessage(
168
  content="""You are tasked with providing a final, comprehensive answer based on the conversation history and tool usage.
169
 
@@ -174,29 +175,27 @@ Analyze all the information gathered from the tools and provide:
174
 
175
  Be honest about limitations and indicate your confidence level appropriately."""
176
  )
177
-
178
  # Get the original user question
179
  user_question = None
180
  for msg in state["messages"]:
181
  if msg.type == "human":
182
  user_question = msg.content
183
  break
184
-
185
  # Create structured output model
186
  structured_model = model.with_structured_output(FinalAnswer)
187
-
188
  messages = [
189
  system_prompt,
190
  HumanMessage(content=f"Original question: {user_question}"),
191
- SystemMessage(content="Based on the following conversation history, provide your final answer:")
192
  ] + state["messages"]
193
-
194
  response = structured_model.invoke(messages, config)
195
-
196
- return {
197
- "messages": [SystemMessage(content=f"Final Answer: {response.answer}")],
198
- "final_answer": response
199
- }
200
 
201
  # Edge functions
202
  def should_continue_sufficiency(state: AgentState):
@@ -209,95 +208,84 @@ def should_continue_sufficiency(state: AgentState):
209
  return "insufficient"
210
  return "insufficient" # Default to insufficient if unclear
211
 
 
212
  def should_continue_react(state: AgentState):
213
  """Decide whether to continue with ReAct loop or move to final answer"""
214
  messages = state["messages"]
215
  last_message = messages[-1]
216
  llm_call_count = state.get("llm_call_count", 0)
217
  max_calls = state.get("max_llm_calls", 4)
218
-
219
  # If we've reached the maximum number of LLM calls, force stop
220
  if llm_call_count >= max_calls:
221
  return "final_answer"
222
-
223
  # If there are no tool calls, we're done with ReAct loop
224
- if not hasattr(last_message, 'tool_calls') or not last_message.tool_calls:
225
  return "final_answer"
226
-
227
  # Otherwise continue with tools
228
  return "continue"
229
 
 
230
  # Build the graph
231
  def create_react_agent_graph():
232
  """Create and return the compiled ReAct agent graph"""
233
-
234
  workflow = StateGraph(AgentState)
235
-
236
  # Add nodes
237
  workflow.add_node("check_sufficiency", check_tool_sufficiency)
238
  workflow.add_node("agent", call_model)
239
  workflow.add_node("tools", tool_node)
240
  workflow.add_node("final_answer", final_answer_node)
241
-
242
  # Set entry point
243
  workflow.set_entry_point("check_sufficiency")
244
-
245
  # Add conditional edge from sufficiency check
246
  workflow.add_conditional_edges(
247
- "check_sufficiency",
248
- should_continue_sufficiency,
249
- {
250
- "sufficient": "agent",
251
- "insufficient": END
252
- }
253
  )
254
-
255
  # Add conditional edge from agent
256
  workflow.add_conditional_edges(
257
- "agent",
258
- should_continue_react,
259
- {
260
- "continue": "tools",
261
- "final_answer": "final_answer"
262
- }
263
  )
264
-
265
  # Add edge from tools back to agent
266
  workflow.add_edge("tools", "agent")
267
-
268
  # Add edge from final_answer to END
269
  workflow.add_edge("final_answer", END)
270
-
271
  return workflow.compile()
272
 
 
273
  # Helper function for running the agent
274
  def run_agent(question: str, max_llm_calls: int = 4):
275
  """Run the ReAct agent with a question"""
276
-
277
  graph = create_react_agent_graph()
278
-
279
- initial_state = {
280
- "messages": [HumanMessage(content=question)],
281
- "llm_call_count": 0,
282
- "max_llm_calls": max_llm_calls
283
- }
284
-
285
  # Stream the execution
286
  print(f"Question: {question}")
287
  print("=" * 50)
288
-
289
  for step in graph.stream(initial_state):
290
  for node, output in step.items():
291
  print(f"\n--- {node.upper()} ---")
292
  if "messages" in output and output["messages"]:
293
  for msg in output["messages"]:
294
- if hasattr(msg, 'content'):
295
  print(f"{msg.__class__.__name__}: {msg.content}")
296
- elif hasattr(msg, 'tool_calls') and msg.tool_calls:
297
  print(f"Tool calls: {[tc['name'] for tc in msg.tool_calls]}")
298
-
299
  if "final_answer" in output:
300
- print(f"\nFINAL STRUCTURED ANSWER:")
301
  print(f"Answer: {output['final_answer'].answer}")
302
  print(f"Confidence: {output['final_answer'].confidence}")
303
  print(f"Sources: {output['final_answer'].sources_used}")
 
1
from collections.abc import Sequence
from typing import Annotated, Literal, NotRequired, TypedDict

from langchain.chat_models import init_chat_model

# Import tools
from langchain_community.tools import DuckDuckGoSearchRun, WikipediaQueryRun
from langchain_community.tools.arxiv import ArxivQueryRun
from langchain_community.tools.pubmed.tool import PubmedQueryRun
from langchain_community.tools.semanticscholar.tool import SemanticScholarQueryRun
from langchain_community.tools.wikidata.tool import WikidataQueryRun
from langchain_community.utilities import WikipediaAPIWrapper
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import Tool
from langchain_experimental.utilities import PythonREPL
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from pydantic import BaseModel, Field
20
 
21
  # Set up tools
22
  python_repl = PythonREPL()
 
34
  ArxivQueryRun(),
35
  WikidataQueryRun(),
36
  WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()),
37
+ repl_tool,
38
  ]
39
 
40
  # Initialize Gemini model
 
44
  # Create tools lookup
45
  tools_by_name = {tool.name: tool for tool in tools}
46
 
47
+
48
# Pydantic models for structured output
class ToolSufficiencyResponse(BaseModel):
    """Structured verdict for the tool-sufficiency check.

    Filled by the LLM via ``model.with_structured_output`` inside
    ``check_tool_sufficiency``; that node stores ``sufficient`` in state under
    the ``tool_sufficiency`` key and echoes ``reasoning`` into the
    conversation as a SystemMessage.
    """

    # True when the LLM judges the available tools adequate for the question.
    sufficient: bool = Field(description="Whether the available tools are sufficient to answer the question")
    # Short free-text justification for the verdict.
    reasoning: str = Field(description="Brief reasoning for the decision")
54
 
55
+
56
class FinalAnswer(BaseModel):
    """Structured schema for the agent's final answer.

    Produced via ``model.with_structured_output(FinalAnswer)`` in the
    final-answer node after the ReAct loop finishes.
    """

    # The synthesized answer text shown to the user.
    answer: str = Field(description="The comprehensive answer to the user's question")
    # Closed set of confidence levels keeps downstream handling simple.
    confidence: Literal["high", "medium", "low"] = Field(description="Confidence level in the answer")
    # Tool/source names the model reports having relied on.
    sources_used: list[str] = Field(description="List of tools/sources that were used to generate the answer")
62
 
63
+
64
# Define graph state
class AgentState(TypedDict):
    """The state of the agent.

    Keys:
        messages: Conversation history; ``add_messages`` merges node updates
            by appending instead of overwriting.
        llm_call_count: Number of LLM invocations made so far in the ReAct loop.
        max_llm_calls: Hard cap on LLM invocations before forcing a final answer.
        tool_sufficiency: Verdict written by the sufficiency-check node.
        final_answer: Structured answer written by the final-answer node.
    """

    messages: Annotated[Sequence[BaseMessage], add_messages]
    llm_call_count: int
    max_llm_calls: int
    # BUGFIX: these keys are returned by check_tool_sufficiency
    # ("tool_sufficiency") and final_answer_node ("final_answer") but were not
    # declared here; LangGraph raises InvalidUpdateError for state keys absent
    # from the schema. NotRequired keeps the initial state (which omits them)
    # valid for type checkers, so this is backward compatible.
    tool_sufficiency: NotRequired[bool]
    final_answer: NotRequired[FinalAnswer]
71
 
72
+
73
  # Node functions
74
  def check_tool_sufficiency(state: AgentState, config: RunnableConfig):
75
  """Check if available tools are sufficient to answer the question"""
76
+
77
  # Get the user's question
78
  user_message = None
79
  for msg in state["messages"]:
80
  if msg.type == "human":
81
  user_message = msg.content
82
  break
83
+
84
  # Create system prompt for sufficiency check
85
  available_tools_desc = "\n".join([f"- {tool.name}: {tool.description}" for tool in tools])
86
+
87
  system_prompt = f"""You are an AI assistant that needs to determine if the available tools are sufficient to answer a user's question.
88
 
89
  Available tools:
 
100
 
101
  # Use structured output for sufficiency check
102
  structured_model = model.with_structured_output(ToolSufficiencyResponse)
103
+
104
+ messages = [SystemMessage(content=system_prompt), HumanMessage(content=f"Question to analyze: {user_message}")]
105
+
 
 
 
106
  response = structured_model.invoke(messages, config)
107
+
108
  # Add response to messages for context
109
  response_message = SystemMessage(
110
  content=f"Tool sufficiency check: {'Sufficient' if response.sufficient else 'Insufficient'}. Reasoning: {response.reasoning}"
111
  )
112
+
113
+ return {"messages": [response_message], "tool_sufficiency": response.sufficient}
114
+
 
 
115
 
116
  def call_model(state: AgentState, config: RunnableConfig):
117
  """Call the model (ReAct agent LLM node)"""
118
+
119
  system_prompt = SystemMessage(
120
  content="""You are a helpful AI assistant with access to various tools. Use the tools available to you to answer the user's question comprehensively.
121
 
 
126
 
127
  Be thorough but efficient with your tool usage."""
128
  )
129
+
130
  response = model_with_tools.invoke([system_prompt] + state["messages"], config)
131
+
132
  # Increment LLM call count
133
  new_count = state.get("llm_call_count", 0) + 1
134
+
135
+ return {"messages": [response], "llm_call_count": new_count}
136
+
 
 
137
 
138
  def tool_node(state: AgentState):
139
  """Execute tools based on the last message's tool calls"""
140
  outputs = []
141
  last_message = state["messages"][-1]
142
+
143
  for tool_call in last_message.tool_calls:
144
  try:
145
  tool_result = tools_by_name[tool_call["name"]].invoke(tool_call["args"])
 
158
  tool_call_id=tool_call["id"],
159
  )
160
  )
161
+
162
  return {"messages": outputs}
163
 
164
+
165
  def final_answer_node(state: AgentState, config: RunnableConfig):
166
  """Generate final structured answer based on conversation history"""
167
+
168
  system_prompt = SystemMessage(
169
  content="""You are tasked with providing a final, comprehensive answer based on the conversation history and tool usage.
170
 
 
175
 
176
  Be honest about limitations and indicate your confidence level appropriately."""
177
  )
178
+
179
  # Get the original user question
180
  user_question = None
181
  for msg in state["messages"]:
182
  if msg.type == "human":
183
  user_question = msg.content
184
  break
185
+
186
  # Create structured output model
187
  structured_model = model.with_structured_output(FinalAnswer)
188
+
189
  messages = [
190
  system_prompt,
191
  HumanMessage(content=f"Original question: {user_question}"),
192
+ SystemMessage(content="Based on the following conversation history, provide your final answer:"),
193
  ] + state["messages"]
194
+
195
  response = structured_model.invoke(messages, config)
196
+
197
+ return {"messages": [SystemMessage(content=f"Final Answer: {response.answer}")], "final_answer": response}
198
+
 
 
199
 
200
  # Edge functions
201
  def should_continue_sufficiency(state: AgentState):
 
208
  return "insufficient"
209
  return "insufficient" # Default to insufficient if unclear
210
 
211
+
212
def should_continue_react(state: AgentState):
    """Route the ReAct loop: run more tools or move on to the final answer.

    Returns "final_answer" when the LLM-call budget is exhausted or the last
    message carries no tool calls; otherwise returns "continue".
    """
    last = state["messages"][-1]
    calls_made = state.get("llm_call_count", 0)
    call_budget = state.get("max_llm_calls", 4)

    # Stop if the budget is spent, or if the agent requested no tools
    # (a missing/empty tool_calls attribute means the agent is done).
    out_of_budget = calls_made >= call_budget
    nothing_to_run = not getattr(last, "tool_calls", None)

    return "final_answer" if out_of_budget or nothing_to_run else "continue"
229
 
230
+
231
# Build the graph
def create_react_agent_graph():
    """Assemble and compile the LangGraph workflow for the ReAct agent.

    Topology: check_sufficiency -> (agent | END); agent -> (tools |
    final_answer); tools -> agent; final_answer -> END.
    """
    graph = StateGraph(AgentState)

    # Register every node under the name referenced by the routing tables below.
    for node_name, node_fn in (
        ("check_sufficiency", check_tool_sufficiency),
        ("agent", call_model),
        ("tools", tool_node),
        ("final_answer", final_answer_node),
    ):
        graph.add_node(node_name, node_fn)

    # The sufficiency gate runs first; an "insufficient" verdict ends the run.
    graph.set_entry_point("check_sufficiency")
    graph.add_conditional_edges(
        "check_sufficiency", should_continue_sufficiency, {"sufficient": "agent", "insufficient": END}
    )

    # ReAct loop: the agent either requests tools or hands off to final_answer.
    graph.add_conditional_edges(
        "agent", should_continue_react, {"continue": "tools", "final_answer": "final_answer"}
    )
    graph.add_edge("tools", "agent")

    # The structured final answer terminates the run.
    graph.add_edge("final_answer", END)

    return graph.compile()
263
 
264
+
265
# Helper function for running the agent
def run_agent(question: str, max_llm_calls: int = 4):
    """Run the ReAct agent on *question*, streaming progress to stdout.

    Args:
        question: The user question to answer.
        max_llm_calls: Budget for LLM invocations in the ReAct loop
            (enforced by ``should_continue_react``).
    """
    graph = create_react_agent_graph()

    initial_state = {"messages": [HumanMessage(content=question)], "llm_call_count": 0, "max_llm_calls": max_llm_calls}

    # Stream the execution
    print(f"Question: {question}")
    print("=" * 50)

    for step in graph.stream(initial_state):
        for node, output in step.items():
            print(f"\n--- {node.upper()} ---")
            if "messages" in output and output["messages"]:
                for msg in output["messages"]:
                    if hasattr(msg, "content"):
                        print(f"{msg.__class__.__name__}: {msg.content}")
                    # BUGFIX: this was an `elif`, but chat messages always expose
                    # a `content` attribute (empty string on tool-calling turns),
                    # so the tool-call listing was unreachable. Check it
                    # independently so requested tools are actually shown.
                    if getattr(msg, "tool_calls", None):
                        print(f"Tool calls: {[tc['name'] for tc in msg.tool_calls]}")

            # NOTE(review): "final_answer" only appears in node output if the
            # key is declared in AgentState's schema — verify the state
            # definition matches what final_answer_node returns.
            if "final_answer" in output:
                print("\nFINAL STRUCTURED ANSWER:")
                print(f"Answer: {output['final_answer'].answer}")
                print(f"Confidence: {output['final_answer'].confidence}")
                print(f"Sources: {output['final_answer'].sources_used}")