mirjam-m commited on
Commit
d4f92d3
·
1 Parent(s): 13ded9e

graph test

Browse files
Files changed (1) hide show
  1. app.py +60 -24
app.py CHANGED
@@ -74,15 +74,17 @@ class BasicAgent:
74
  self.model = ChatOpenAI(model="gpt-4o", temperature=0)
75
  self.graph = StateGraph(AnswerState)
76
  self.graph.add_node("log_question", self.log_question)
77
- self.graph.add_node("try_answer", self.try_answer)
 
78
  self.graph.add_node("final_answer", self.final_answer)
79
 
80
  self.graph.add_edge(START, "log_question")
 
81
 
82
  # Add conditional edges
83
  self.graph.add_conditional_edges(
84
- "log_question",
85
- self.try_answer,
86
  {
87
  "FINAL_ANSWER": "final_answer",
88
  "GOOGLE_SEARCH": "final_answer",
@@ -124,36 +126,24 @@ class BasicAgent:
124
  "is_final_answer": True,
125
  }
126
 
127
- def try_answer(self, state: AnswerState) -> str:
128
- if state["attempt"] > 3:
 
 
129
  state["answer"] = "Exceeded max number of attempts."
130
  return "FINAL_ANSWER"
131
- else:
132
- state["attempt"] += 1
133
-
134
- print("[try_answer] Agent trying to answer")
135
-
136
- response = self.model.invoke(state["messages"])
137
- print(f"Agent response: {response.content}")
138
- state["messages"].append(
139
- ChatMessage(
140
- role="assistant",
141
- content=response.content,
142
- )
143
- )
144
 
145
- if "FINAL ANSWER:" in response.content:
146
- answer = response.text().split("FINAL ANSWER:")[-1].strip()
 
147
  print(f"Agent final answer: {answer}")
148
  print(f"Agent returning answer: {state['answer']}")
149
  state["answer"] = answer
150
 
151
  return "FINAL_ANSWER"
152
 
153
- if "TOOL: GoogleSearchAgent" in response.content:
154
- request_full = (
155
- response.text().split("TOOL: GoogleSearchAgent(")[-1].strip(")")
156
- )
157
  print(f"Tool invocation request: {request_full}")
158
  request = request_full.split("request=")[-1].strip("'")
159
  print(f"Search request: {request}")
@@ -163,6 +153,52 @@ class BasicAgent:
163
 
164
  return "FINAL_ANSWER"
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
  def run_and_submit_all(profile: gr.OAuthProfile | None):
168
  """
 
74
  self.model = ChatOpenAI(model="gpt-4o", temperature=0)
75
  self.graph = StateGraph(AnswerState)
76
  self.graph.add_node("log_question", self.log_question)
77
+ self.graph.add_node("invoke_model", self.invoke_model)
78
+ self.graph.add_node("parse_response", self.parse_response)
79
  self.graph.add_node("final_answer", self.final_answer)
80
 
81
  self.graph.add_edge(START, "log_question")
82
+ self.graph.add_edge("log_question", "invoke_model")
83
 
84
  # Add conditional edges
85
  self.graph.add_conditional_edges(
86
+ "invoke_model",
87
+ self.parse_response,
88
  {
89
  "FINAL_ANSWER": "final_answer",
90
  "GOOGLE_SEARCH": "final_answer",
 
126
  "is_final_answer": True,
127
  }
128
 
129
+ def parse_response(self, state: AnswerState) -> str:
130
+ print("[parse_response] parsing last chat response")
131
+
132
+ if state["attempt"] > 0:
133
  state["answer"] = "Exceeded max number of attempts."
134
  return "FINAL_ANSWER"
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
+ content = self.content2text(state["messages"][-1].content)
137
+ if "FINAL ANSWER:" in content:
138
+ answer = content.split("FINAL ANSWER:")[-1].strip()
139
  print(f"Agent final answer: {answer}")
140
  print(f"Agent returning answer: {state['answer']}")
141
  state["answer"] = answer
142
 
143
  return "FINAL_ANSWER"
144
 
145
+ if "TOOL: GoogleSearchAgent" in content:
146
+ request_full = content.split("TOOL: GoogleSearchAgent(")[-1].strip(")")
 
 
147
  print(f"Tool invocation request: {request_full}")
148
  request = request_full.split("request=")[-1].strip("'")
149
  print(f"Search request: {request}")
 
153
 
154
  return "FINAL_ANSWER"
155
 
156
+ print("[parse_response] Agent parsing response")
157
+ if state["is_final_answer"]:
158
+ return "FINAL_ANSWER"
159
+
160
+ if state["search_request"]:
161
+ return "GOOGLE_SEARCH"
162
+
163
+ return "FINAL_ANSWER"
164
+
165
+ def content2text(self, content) -> str:
166
+ """Get the text content of the message.
167
+
168
+ Returns:
169
+ The text content of the message.
170
+ """
171
+ if isinstance(content, str):
172
+ return content
173
+
174
+ # must be a list
175
+ blocks = [
176
+ block
177
+ for block in content
178
+ if isinstance(block, str)
179
+ or (block.get("type") == "text" and isinstance(block.get("text"), str))
180
+ ]
181
+ return "".join(
182
+ block if isinstance(block, str) else block["text"] for block in blocks
183
+ )
184
+
185
+ def invoke_model(self, state: AnswerState) -> Dict[str, Any]:
186
+ print("[invoke_model] Agent trying to answer")
187
+
188
+ response = self.model.invoke(state["messages"])
189
+ print(f"Agent response: {response.content}")
190
+ state["messages"].append(
191
+ ChatMessage(
192
+ role="assistant",
193
+ content=response.content,
194
+ )
195
+ )
196
+
197
+ return {
198
+ "messages": state["messages"],
199
+ "attempt": state.get("attempt", 0) + 1,
200
+ }
201
+
202
 
203
  def run_and_submit_all(profile: gr.OAuthProfile | None):
204
  """