Update agent.py
Browse files
agent.py
CHANGED
|
@@ -12,6 +12,7 @@ from langchain_core.tools import tool
|
|
| 12 |
from langchain_community.document_loaders import WikipediaLoader
|
| 13 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 14 |
from langchain_tavily import TavilySearch
|
|
|
|
| 15 |
|
| 16 |
load_dotenv()
|
| 17 |
|
|
@@ -104,16 +105,54 @@ class LangGraphAgent:
|
|
| 104 |
self.graph = build_graph()
|
| 105 |
print("LangGraphAgent initialized with tools.")
|
| 106 |
|
| 107 |
-
def __call__(self, question: str) ->
|
| 108 |
-
"""Run the agent on a question and return
|
| 109 |
try:
|
| 110 |
messages = [HumanMessage(content=question)]
|
| 111 |
result = self.graph.invoke({"messages": messages})
|
|
|
|
|
|
|
| 112 |
for m in result["messages"]:
|
| 113 |
m.pretty_print()
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
except Exception as e:
|
| 116 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
if __name__ == "__main__":
|
| 119 |
agent = LangGraphAgent()
|
|
|
|
| 12 |
from langchain_community.document_loaders import WikipediaLoader
|
| 13 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 14 |
from langchain_tavily import TavilySearch
|
| 15 |
+
import json
|
| 16 |
|
| 17 |
load_dotenv()
|
| 18 |
|
|
|
|
| 105 |
self.graph = build_graph()
|
| 106 |
print("LangGraphAgent initialized with tools.")
|
| 107 |
|
| 108 |
+
def __call__(self, question: str, task_id: str = None) -> dict:
|
| 109 |
+
"""Run the agent on a question and return structured answer with reasoning trace"""
|
| 110 |
try:
|
| 111 |
messages = [HumanMessage(content=question)]
|
| 112 |
result = self.graph.invoke({"messages": messages})
|
| 113 |
+
|
| 114 |
+
# Print all messages for debugging
|
| 115 |
for m in result["messages"]:
|
| 116 |
m.pretty_print()
|
| 117 |
+
|
| 118 |
+
# Extract the final answer and build reasoning trace
|
| 119 |
+
final_answer = result["messages"][-1].content
|
| 120 |
+
|
| 121 |
+
# Build reasoning trace from all messages
|
| 122 |
+
reasoning_steps = []
|
| 123 |
+
for i, msg in enumerate(result["messages"]):
|
| 124 |
+
if isinstance(msg, SystemMessage):
|
| 125 |
+
reasoning_steps.append(f"Step {i+1}: System prompt loaded for ReAct methodology")
|
| 126 |
+
elif isinstance(msg, HumanMessage):
|
| 127 |
+
reasoning_steps.append(f"Step {i+1}: Received question: {msg.content}")
|
| 128 |
+
elif isinstance(msg, AIMessage):
|
| 129 |
+
if msg.tool_calls:
|
| 130 |
+
for tool_call in msg.tool_calls:
|
| 131 |
+
reasoning_steps.append(f"Step {i+1}: Called tool '{tool_call['name']}' with args: {tool_call['args']}")
|
| 132 |
+
else:
|
| 133 |
+
reasoning_steps.append(f"Step {i+1}: AI reasoning: {msg.content}")
|
| 134 |
+
else:
|
| 135 |
+
reasoning_steps.append(f"Step {i+1}: Tool response: {str(msg.content)[:200]}...")
|
| 136 |
+
|
| 137 |
+
reasoning_trace = " | ".join(reasoning_steps)
|
| 138 |
+
|
| 139 |
+
# Extract the final answer from the AI response
|
| 140 |
+
model_answer = final_answer
|
| 141 |
+
if "FINAL ANSWER:" in final_answer:
|
| 142 |
+
model_answer = final_answer.split("FINAL ANSWER:")[-1].strip()
|
| 143 |
+
|
| 144 |
+
return {
|
| 145 |
+
"task_id": task_id or "unknown",
|
| 146 |
+
"model_answer": model_answer,
|
| 147 |
+
"reasoning_trace": reasoning_trace
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
except Exception as e:
|
| 151 |
+
return {
|
| 152 |
+
"task_id": task_id or "unknown",
|
| 153 |
+
"model_answer": f"Error: {str(e)}",
|
| 154 |
+
"reasoning_trace": f"Error occurred during processing: {str(e)}"
|
| 155 |
+
}
|
| 156 |
|
| 157 |
if __name__ == "__main__":
|
| 158 |
agent = LangGraphAgent()
|