Tomkuijpers2232 commited on
Commit
922ff2c
·
verified ·
1 Parent(s): a65ac63

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +43 -4
agent.py CHANGED
@@ -12,6 +12,7 @@ from langchain_core.tools import tool
12
  from langchain_community.document_loaders import WikipediaLoader
13
  from langchain_google_genai import ChatGoogleGenerativeAI
14
  from langchain_tavily import TavilySearch
 
15
 
16
  load_dotenv()
17
 
@@ -104,16 +105,54 @@ class LangGraphAgent:
104
  self.graph = build_graph()
105
  print("LangGraphAgent initialized with tools.")
106
 
107
- def __call__(self, question: str) -> str:
108
- """Run the agent on a question and return the answer"""
109
  try:
110
  messages = [HumanMessage(content=question)]
111
  result = self.graph.invoke({"messages": messages})
 
 
112
  for m in result["messages"]:
113
  m.pretty_print()
114
- return result["messages"][-1].content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  except Exception as e:
116
- return f"Error: {str(e)}"
 
 
 
 
117
 
118
  if __name__ == "__main__":
119
  agent = LangGraphAgent()
 
12
  from langchain_community.document_loaders import WikipediaLoader
13
  from langchain_google_genai import ChatGoogleGenerativeAI
14
  from langchain_tavily import TavilySearch
15
+ import json
16
 
17
  load_dotenv()
18
 
 
105
  self.graph = build_graph()
106
  print("LangGraphAgent initialized with tools.")
107
 
108
+ def __call__(self, question: str, task_id: str = None) -> dict:
109
+ """Run the agent on a question and return structured answer with reasoning trace"""
110
  try:
111
  messages = [HumanMessage(content=question)]
112
  result = self.graph.invoke({"messages": messages})
113
+
114
+ # Print all messages for debugging
115
  for m in result["messages"]:
116
  m.pretty_print()
117
+
118
+ # Extract the final answer and build reasoning trace
119
+ final_answer = result["messages"][-1].content
120
+
121
+ # Build reasoning trace from all messages
122
+ reasoning_steps = []
123
+ for i, msg in enumerate(result["messages"]):
124
+ if isinstance(msg, SystemMessage):
125
+ reasoning_steps.append(f"Step {i+1}: System prompt loaded for ReAct methodology")
126
+ elif isinstance(msg, HumanMessage):
127
+ reasoning_steps.append(f"Step {i+1}: Received question: {msg.content}")
128
+ elif isinstance(msg, AIMessage):
129
+ if msg.tool_calls:
130
+ for tool_call in msg.tool_calls:
131
+ reasoning_steps.append(f"Step {i+1}: Called tool '{tool_call['name']}' with args: {tool_call['args']}")
132
+ else:
133
+ reasoning_steps.append(f"Step {i+1}: AI reasoning: {msg.content}")
134
+ else:
135
+ reasoning_steps.append(f"Step {i+1}: Tool response: {str(msg.content)[:200]}...")
136
+
137
+ reasoning_trace = " | ".join(reasoning_steps)
138
+
139
+ # Extract the final answer from the AI response
140
+ model_answer = final_answer
141
+ if "FINAL ANSWER:" in final_answer:
142
+ model_answer = final_answer.split("FINAL ANSWER:")[-1].strip()
143
+
144
+ return {
145
+ "task_id": task_id or "unknown",
146
+ "model_answer": model_answer,
147
+ "reasoning_trace": reasoning_trace
148
+ }
149
+
150
  except Exception as e:
151
+ return {
152
+ "task_id": task_id or "unknown",
153
+ "model_answer": f"Error: {str(e)}",
154
+ "reasoning_trace": f"Error occurred during processing: {str(e)}"
155
+ }
156
 
157
  if __name__ == "__main__":
158
  agent = LangGraphAgent()