# Source header (Hugging Face file page): alisamak — "Update LG_agent.py", commit b8c3c10 (verified)
from typing import Annotated, TypedDict
from langgraph.graph.message import add_messages
from langchain_core.messages import HumanMessage, AIMessage, AnyMessage, SystemMessage
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.graph import START, StateGraph
from langchain_openai import ChatOpenAI
from tools import all_tools
import inspect
import os
import re
# 1. Setup once
# Read the OpenAI key from the environment and fail fast if it is absent,
# so a misconfigured deployment errors at import time rather than mid-run.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise ValueError("Missing OPENAI_API_KEY environment variable.")
# Deterministic (temperature=0) chat model used by the assistant node.
chat = ChatOpenAI(
    model="gpt-3.5-turbo",
    openai_api_key=OPENAI_API_KEY,
    temperature=0,
)
# Same model with the project's tools bound, so responses can carry tool calls.
chat_with_tools = chat.bind_tools(all_tools)
# 2. Define the agent state
class AgentState(TypedDict):
    """Graph state: the running conversation.

    ``add_messages`` is a reducer — message lists returned by nodes are
    appended to the existing history instead of replacing it.
    """
    messages: Annotated[list[AnyMessage], add_messages]
def extract_gaia_answer(text: str) -> str:
    """
    Pull the bare final answer out of a model response, lowercased.

    Tries, in order:
      1. a 'The answer is: ...' suffix,
      2. an 'Answer: ...' suffix,
      3. a whole line made only of letters, digits, spaces, commas, hyphens,
      4. the first non-empty line when shorter than 80 characters,
      5. the entire stripped text as a last resort.
    """
    cleaned = text.strip()
    flags = re.IGNORECASE | re.MULTILINE
    answer_patterns = (
        r"The answer is:\s*(.+)",
        r"Answer:\s*(.+)",
        r"^([a-z0-9\s,\-]+)$",  # a plain line of simple characters
    )
    for regex in answer_patterns:
        hit = re.search(regex, cleaned, flags)
        if hit:
            return hit.group(1).strip().lower()

    # Heuristic fallback: a short first line is probably the answer itself.
    non_empty = [line.strip() for line in cleaned.splitlines() if line.strip()]
    if non_empty and len(non_empty[0]) < 80:
        return non_empty[0].lower()

    # Give up on extraction — hand back everything, lowercased.
    return cleaned.lower()
# 3. Assistant node
def assistant(state: AgentState):
    """
    Core LLM node of the graph.

    Builds a system prompt embedding every tool's signature and description,
    prepends it to the conversation so far, and invokes the tool-bound chat
    model once.

    Args:
        state: Current graph state; ``state["messages"]`` is the history.

    Returns:
        Partial state update ``{"messages": [response]}``; the
        ``add_messages`` reducer appends it to the history.
    """
    # One "name(signature):\n description" entry per tool so the model can
    # choose tools sensibly.  NOTE(review): assumes every tool exposes a
    # callable `.func` attribute — confirm against tools.all_tools.
    tool_descriptions = "\n".join([
        f"{tool.name}{inspect.signature(tool.func)}:\n {tool.description.strip()}"
        for tool in all_tools
    ])
    sys_msg = SystemMessage(
        content=(
            "You are a helpful AI assistant who solves GAIA benchmark questions using step-by-step reasoning.\n"
            "Before answering, always think out loud and plan your approach.\n"
            "Use tools when you lack information or need external data. Only use audio or transcription tools if the user clearly provides or references an audio file.\n"
            "Do not assume the existence of files or media unless they are explicitly mentioned. Do not call tools like transcription unless an actual file or media reference is present.\n"
            "After every tool call, always analyze the result and continue reasoning to arrive at a final answer.\n"
            # BUG FIX: this line previously lacked a trailing "\n", so it ran
            # straight into the next sentence in the rendered prompt.  Also
            # repaired the mojibake em dash ("β€”" -> "—").
            "If the question is unclear, incomplete, or missing context, respond with: **'The question is incomplete — please provide more information.'**\n"
            "Never treat tool outputs as final — interpret them and continue solving the task step-by-step.\n"
            "When the question specifies an answer format (e.g., a number, list, or code), respond **only with the final answer** in the required format. Do not add explanations, quotes, or set notation. Output exactly what is requested.\n"
            "Finish with a clear and concise answer, such as 'The answer is: right'.\n"
            "\nAvailable tools:\n"
            f"{tool_descriptions}"
        )
    )
    input_msgs = [sys_msg] + state["messages"]
    # Debug trace: show a preview of each inbound message.
    # (Emoji below were mojibake in the scraped source; restored to UTF-8.)
    print("\n🧠 Assistant received messages:")
    for msg in input_msgs:
        print(f"🔹 {msg.__class__.__name__}: {getattr(msg, 'content', '')[:200]}")
    output = chat_with_tools.invoke(input_msgs)
    print("\n🗣️ Assistant response:")
    print("-" * 40)
    print(getattr(output, 'content', '')[:500])
    print("-" * 40)
    return {
        "messages": [output],
    }
# 4. Build the agent graph
def build_graph(max_steps: int = 5):
    """
    Compile the assistant/tools LangGraph and wrap it in a step-limited driver.

    Args:
        max_steps: Maximum number of outer invoke iterations the returned
            driver will run before giving up.

    Returns:
        ``limited_invoke(state)`` — a callable that repeatedly invokes the
        compiled graph until a final answer is assumed or the step budget
        is exhausted.
    """
    builder = StateGraph(AgentState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(all_tools))
    builder.add_edge(START, "assistant")
    # tools_condition routes to "tools" when the last AI message carries
    # tool calls, otherwise to END.
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    graph = builder.compile()

    # BUG FIX: the inner default was a hard-coded 5, shadowing and silently
    # ignoring build_graph's max_steps argument (BasicAgent(max_steps=N) had
    # no effect).  Binding the outer value as the default at definition time
    # restores the intended behavior; callers may still override per call.
    def limited_invoke(state, max_steps: int = max_steps, max_reasoning_steps_after_tool: int = 2):
        steps = 0
        reasoning_steps_since_last_tool = 0
        while steps < max_steps:
            print(f"\U0001f501 Step {steps + 1}")
            state = graph.invoke(state)
            # Echo every assistant utterance for debugging.
            # (Emoji below were mojibake in the scraped source; restored.)
            for msg in state["messages"]:
                if isinstance(msg, AIMessage):
                    print("\n🤖 Assistant says:")
                    print("-" * 40)
                    print(msg.content.strip())
                    print("-" * 40)
            latest_message = state["messages"][-1] if state["messages"] else None
            if isinstance(latest_message, AIMessage):
                if latest_message.tool_calls:
                    print("🔄 Tool call detected — continuing loop.")
                    reasoning_steps_since_last_tool = 0  # reset counter
                else:
                    reasoning_steps_since_last_tool += 1
                    print(f"🧠 No tool call — reasoning step #{reasoning_steps_since_last_tool}")
                    # 🛠️ Handle reverse_sentence manually: re-feed the reversed
                    # tool output as a fresh human message so the model can
                    # answer the decoded question.
                    if "reverse_sentence" in latest_message.content.lower():
                        tool_outputs = [msg for msg in state["messages"] if msg.type == "tool"]
                        if tool_outputs:
                            reversed_text = tool_outputs[-1].content.strip()
                            print(f"🔁 Re-feeding reversed message:\n{reversed_text}")
                            state["messages"].append(HumanMessage(content=reversed_text))
                            continue  # loop again with new input
                    if reasoning_steps_since_last_tool >= max_reasoning_steps_after_tool:
                        print("✅ Final answer assumed after sufficient reasoning.")
                        break
            steps += 1
        return state

    return limited_invoke
# 5. BasicAgent class
# class BasicAgent:
# def __init__(self, max_steps: int = 5):
# self.graph = build_graph(max_steps)
# def __call__(self, question: str) -> str:
# response = self.graph({"messages": [HumanMessage(content=question)]})
# if response.get("messages"):
# final_message = response["messages"][-1]
# return final_message.content if hasattr(final_message, "content") else "No final message."
# else:
# return "No response."
class BasicAgent:
    """Thin callable wrapper around the step-limited LangGraph agent.

    Builds the graph driver once at construction, then exposes a simple
    question-in / answer-out interface.
    """

    def __init__(self, max_steps: int = 5):
        # limited_invoke driver produced by build_graph
        self.graph = build_graph(max_steps)

    def __call__(self, question: str) -> str:
        """Run the agent on *question* and return the extracted raw answer."""
        initial_state = {"messages": [HumanMessage(content=question)]}
        result = self.graph(initial_state)
        messages = result.get("messages")
        if not messages:
            return "No response."
        last = messages[-1]
        raw = last.content if hasattr(last, "content") else "No final message."
        return extract_gaia_answer(raw)
if __name__ == "__main__":
    # Quick smoke test when run as a script.
    demo = BasicAgent()
    print(demo("What is the capital of France?"))