grshot committed on
Commit
890d080
·
1 Parent(s): e628580

refactor tool validations

Browse files
Files changed (2) hide show
  1. agent.py +33 -2
  2. app.py +78 -14
agent.py CHANGED
@@ -217,8 +217,39 @@ def build_agent_graph(provider: str = "groq"):
217
  else:
218
  return {"messages": [system_prompt] + state["messages"]}
219
 
220
- # ToolNode wrapper for actual tool use
221
- tool_node = ToolNode(tools)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
 
223
  # Define error handling node
224
  def error_handler_node(state: MessagesState) -> dict:
 
217
  else:
218
  return {"messages": [system_prompt] + state["messages"]}
219
 
220
+ # Wrap tools with validation
221
+ def wrap_tool_with_validation(tool):
222
+ original_func = tool.__call__
223
+
224
+ def validated_call(*args, **kwargs):
225
+ response = original_func(*args, **kwargs)
226
+
227
+ try:
228
+ if not isinstance(response, dict):
229
+ raise ValueError(
230
+ f"Tool response must be a dict, got {type(response)}"
231
+ )
232
+
233
+ # Check for common response keys
234
+ for key in ["web_results", "wiki_results", "transcript_results"]:
235
+ if key in response:
236
+ if not isinstance(response[key], str):
237
+ raise ValueError(
238
+ f"Tool response[{key}] must be string, got {type(response[key])}"
239
+ )
240
+ if not response[key].strip():
241
+ raise ValueError(f"Tool response[{key}] is empty")
242
+
243
+ return response
244
+ except Exception as e:
245
+ return {"error": f"Tool response validation failed: {str(e)}"}
246
+
247
+ tool.__call__ = validated_call
248
+ return tool
249
+
250
+ # Apply validation wrapper to each tool
251
+ validated_tools = [wrap_tool_with_validation(tool) for tool in tools]
252
+ tool_node = ToolNode(validated_tools)
253
 
254
  # Define error handling node
255
  def error_handler_node(state: MessagesState) -> dict:
app.py CHANGED
@@ -1,12 +1,13 @@
1
  import inspect
2
  import os
 
3
 
4
  import gradio as gr
5
  import pandas as pd
6
  import requests
7
- from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage
 
8
 
9
- # from langgraph.graph import MessagesState
10
  from agent import build_agent_graph
11
 
12
  # (Keep Constants as is)
@@ -23,12 +24,56 @@ class BasicAgent:
23
 
24
  def __call__(self, question: str) -> str:
25
  print(f"Agent received question (first 50 chars): {question[:50]}...")
26
- # Wrap the question from HumanMessage from langchain_core
27
- msgs = [HumanMessage(content=question)]
28
- # input_state: MessagesState = {"messages": msgs}
29
- result = self.graph.invoke({"messages": msgs})
30
- answer = result["messages"][-1].content
31
- return answer[14:] # skip "FINAL ANSWER: "
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
 
34
  def run_and_submit_all(profile: gr.OAuthProfile | None):
@@ -62,6 +107,7 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
62
 
63
  # 2. Fetch Questions
64
  print(f"Fetching questions from: {questions_url}")
 
65
  try:
66
  response = requests.get(questions_url, timeout=15)
67
  response.raise_for_status()
@@ -70,13 +116,14 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
70
  print("Fetched questions list is empty.")
71
  return "Fetched questions list is empty or invalid format.", None
72
  print(f"Fetched {len(questions_data)} questions.")
73
- except requests.exceptions.RequestException as e:
74
- print(f"Error fetching questions: {e}")
75
- return f"Error fetching questions: {e}", None
76
  except requests.exceptions.JSONDecodeError as e:
 
77
  print(f"Error decoding JSON response from questions endpoint: {e}")
78
- print(f"Response text: {response.text[:500]}")
79
  return f"Error decoding server response for questions: {e}", None
 
 
 
80
  except Exception as e:
81
  print(f"An unexpected error occurred fetching questions: {e}")
82
  return f"An unexpected error occurred fetching questions: {e}", None
@@ -93,6 +140,17 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
93
  continue
94
  try:
95
  submitted_answer = agent(question_text)
 
 
 
 
 
 
 
 
 
 
 
96
  answers_payload.append(
97
  {"task_id": task_id, "submitted_answer": submitted_answer}
98
  )
@@ -103,13 +161,19 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
103
  "Submitted Answer": submitted_answer,
104
  }
105
  )
 
106
  except Exception as e:
107
- print(f"Error running agent on task {task_id}: {e}")
 
 
 
 
 
108
  results_log.append(
109
  {
110
  "Task ID": task_id,
111
  "Question": question_text,
112
- "Submitted Answer": f"AGENT ERROR: {e}",
113
  }
114
  )
115
 
 
1
  import inspect
2
  import os
3
+ from typing import Dict, List, TypedDict, cast
4
 
5
  import gradio as gr
6
  import pandas as pd
7
  import requests
8
+ from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, SystemMessage
9
+ from langgraph.graph import MessagesState
10
 
 
11
  from agent import build_agent_graph
12
 
13
  # (Keep Constants as is)
 
24
 
25
  def __call__(self, question: str) -> str:
26
  print(f"Agent received question (first 50 chars): {question[:50]}...")
27
+ try:
28
+ # Create properly typed messages
29
+ system_msg = SystemMessage(
30
+ content="""You are a helpful AI assistant. Format your final answer as:
31
+ FINAL ANSWER: [your answer here]"""
32
+ )
33
+ human_msg = HumanMessage(content=question)
34
+ msgs: List[AnyMessage] = [system_msg, human_msg]
35
+
36
+ # Create and cast the state
37
+ input_state = cast(MessagesState, {"messages": msgs})
38
+
39
+ # Invoke the graph
40
+ result = self.graph.invoke(input_state)
41
+
42
+ # Validate response
43
+ if not isinstance(result, dict) or "messages" not in result:
44
+ raise ValueError("Invalid response structure from agent graph")
45
+
46
+ if not result["messages"]:
47
+ raise ValueError("Empty message list in response")
48
+
49
+ # Get the last message content
50
+ last_msg = result["messages"][-1]
51
+ if not isinstance(last_msg, (AIMessage, HumanMessage, SystemMessage)):
52
+ raise ValueError(f"Invalid message type: {type(last_msg)}")
53
+
54
+ answer = last_msg.content
55
+ if not isinstance(answer, str):
56
+ raise ValueError(f"Invalid answer type: {type(answer)}")
57
+
58
+ # Ensure proper formatting
59
+ if not answer.strip():
60
+ return "Error: Empty response from agent"
61
+
62
+ if "FINAL ANSWER:" not in answer:
63
+ # If no prefix, return as is
64
+ return answer
65
+
66
+ # Extract the actual answer after "FINAL ANSWER:"
67
+ final_answer = answer.split("FINAL ANSWER:", 1)[1].strip()
68
+ if not final_answer:
69
+ return "Error: Empty answer after FINAL ANSWER prefix"
70
+
71
+ return final_answer
72
+
73
+ except Exception as e:
74
+ error_msg = f"Error in agent call: {str(e)}"
75
+ print(error_msg)
76
+ return error_msg
77
 
78
 
79
  def run_and_submit_all(profile: gr.OAuthProfile | None):
 
107
 
108
  # 2. Fetch Questions
109
  print(f"Fetching questions from: {questions_url}")
110
+ response = None
111
  try:
112
  response = requests.get(questions_url, timeout=15)
113
  response.raise_for_status()
 
116
  print("Fetched questions list is empty.")
117
  return "Fetched questions list is empty or invalid format.", None
118
  print(f"Fetched {len(questions_data)} questions.")
 
 
 
119
  except requests.exceptions.JSONDecodeError as e:
120
+ error_text = response.text[:500] if response else "No response text available"
121
  print(f"Error decoding JSON response from questions endpoint: {e}")
122
+ print(f"Response text: {error_text}")
123
  return f"Error decoding server response for questions: {e}", None
124
+ except requests.exceptions.RequestException as e:
125
+ print(f"Error fetching questions: {e}")
126
+ return f"Error fetching questions: {e}", None
127
  except Exception as e:
128
  print(f"An unexpected error occurred fetching questions: {e}")
129
  return f"An unexpected error occurred fetching questions: {e}", None
 
140
  continue
141
  try:
142
  submitted_answer = agent(question_text)
143
+ # Remove "FINAL ANSWER: " prefix if present
144
+ if isinstance(submitted_answer, str) and submitted_answer.startswith(
145
+ "FINAL ANSWER: "
146
+ ):
147
+ submitted_answer = submitted_answer[14:]
148
+
149
+ # Handle empty or invalid answers
150
+ if not submitted_answer:
151
+ print(f"Warning: Empty answer for task {task_id}")
152
+ submitted_answer = "Error: Agent produced empty response"
153
+
154
  answers_payload.append(
155
  {"task_id": task_id, "submitted_answer": submitted_answer}
156
  )
 
161
  "Submitted Answer": submitted_answer,
162
  }
163
  )
164
+ print(f"Successfully processed task {task_id}")
165
  except Exception as e:
166
+ error_msg = f"Error running agent on task {task_id}: {str(e)}"
167
+ print(error_msg)
168
+ error_answer = f"Error: {str(e)}"
169
+ answers_payload.append(
170
+ {"task_id": task_id, "submitted_answer": error_answer}
171
+ )
172
  results_log.append(
173
  {
174
  "Task ID": task_id,
175
  "Question": question_text,
176
+ "Submitted Answer": error_answer,
177
  }
178
  )
179