# NOTE(review): the original paste began with stray "Spaces:" / "Runtime error"
# lines and pipe-table decoration -- artifacts of the copy, not program text.
| from typing import Annotated, TypedDict | |
| from langgraph.graph.message import add_messages | |
| from langchain_core.messages import HumanMessage, AIMessage, AnyMessage, SystemMessage | |
| from langgraph.prebuilt import ToolNode, tools_condition | |
| from langgraph.graph import START, StateGraph | |
| from langchain_openai import ChatOpenAI | |
| from tools import all_tools | |
| import inspect | |
| import os | |
| import re | |
# 1. One-time setup: read credentials and build the tool-aware chat model.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    # Fail fast at import time rather than on the first model call.
    raise ValueError("Missing OPENAI_API_KEY environment variable.")

# Deterministic (temperature=0) chat model used by the assistant node.
chat = ChatOpenAI(
    model="gpt-3.5-turbo",
    openai_api_key=OPENAI_API_KEY,
    temperature=0,
)

# Expose the tool schemas to the model so it can emit structured tool calls.
chat_with_tools = chat.bind_tools(all_tools)
# 2. Define the agent state
class AgentState(TypedDict):
    """Graph state: the running conversation, merged via ``add_messages``."""

    messages: Annotated[list[AnyMessage], add_messages]
def extract_gaia_answer(text: str) -> str:
    """Strip explanations/prefixes and return only the final answer, lowercased.

    Recognised forms, tried in priority order:
      1. ``The answer is: <answer>``
      2. ``Answer: <answer>``
      3. a bare line made only of letters/digits/spaces/commas/hyphens

    Falls back to the first short (<80 chars) non-blank line, and finally to
    the whole stripped text.
    """
    cleaned = text.strip()
    answer_patterns = (
        r"The answer is:\s*(.+)",
        r"Answer:\s*(.+)",
        r"^([a-z0-9\s,\-]+)$",  # simple raw line (numbers, plain text)
    )
    for pat in answer_patterns:
        found = re.search(pat, cleaned, re.IGNORECASE | re.MULTILINE)
        if found:
            return found.group(1).strip().lower()

    # Fallback: a short leading line is probably the answer itself.
    nonblank = [ln.strip() for ln in cleaned.splitlines() if ln.strip()]
    if nonblank and len(nonblank[0]) < 80:
        return nonblank[0].lower()

    # Last resort: the whole text, lowercased.
    return cleaned.lower()
# 3. Assistant node
def assistant(state: AgentState):
    """LLM node: prepend the GAIA system prompt and invoke the tool-bound model.

    Returns a partial state update ``{"messages": [output]}``, which the
    ``add_messages`` reducer appends to the running conversation.
    """
    # Render "name(signature): description" for every tool so the model knows
    # what it can call.
    # NOTE(review): assumes each tool exposes a plain callable at `tool.func`
    # (true for @tool / StructuredTool) -- confirm for other tool classes.
    tool_descriptions = "\n".join([
        f"{tool.name}{inspect.signature(tool.func)}:\n {tool.description.strip()}"
        for tool in all_tools
    ])
    sys_msg = SystemMessage(
        content=(
            "You are a helpful AI assistant who solves GAIA benchmark questions using step-by-step reasoning.\n"
            "Before answering, always think out loud and plan your approach.\n"
            "Use tools when you lack information or need external data. Only use audio or transcription tools if the user clearly provides or references an audio file.\n"
            "Do not assume the existence of files or media unless they are explicitly mentioned. Do not call tools like transcription unless an actual file or media reference is present.\n"
            "After every tool call, always analyze the result and continue reasoning to arrive at a final answer.\n"
            # BUG FIX: this line previously lacked a trailing "\n", so it ran
            # together with the next instruction in the rendered prompt.
            "If the question is unclear, incomplete, or missing context, respond with: **'The question is incomplete β please provide more information.'**\n"
            "Never treat tool outputs as final β interpret them and continue solving the task step-by-step.\n"
            "When the question specifies an answer format (e.g., a number, list, or code), respond **only with the final answer** in the required format. Do not add explanations, quotes, or set notation. Output exactly what is requested.\n"
            "Finish with a clear and concise answer, such as 'The answer is: right'.\n"
            "\nAvailable tools:\n"
            f"{tool_descriptions}"
        )
    )
    input_msgs = [sys_msg] + state["messages"]

    # Debug trace of what the model will see (each message truncated to 200 chars).
    print("\nπ§ Assistant received messages:")
    for msg in input_msgs:
        print(f"πΉ {msg.__class__.__name__}: {getattr(msg, 'content', '')[:200]}")

    output = chat_with_tools.invoke(input_msgs)

    print("\nπ£οΈ Assistant response:")
    print("-" * 40)
    print(getattr(output, 'content', '')[:500])
    print("-" * 40)
    return {
        "messages": [output],
    }
# 4. Build the agent graph
def build_graph(max_steps: int = 5):
    """Compile the assistant/tools graph and return a step-capped invoke callable.

    The returned ``limited_invoke(state)`` repeatedly invokes the compiled
    graph until either ``max_steps`` iterations have run, or the model has
    produced enough consecutive tool-free reasoning turns to be treated as a
    final answer.
    """
    builder = StateGraph(AgentState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(all_tools))
    builder.add_edge(START, "assistant")
    # Route to "tools" when the last AI message carries tool calls, else END.
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    graph = builder.compile()

    # BUG FIX: `max_steps` previously defaulted to a hard-coded 5 here, so the
    # value passed to build_graph() was silently ignored. Binding the outer
    # argument as the default keeps the signature backward compatible.
    def limited_invoke(state, max_steps: int = max_steps, max_reasoning_steps_after_tool: int = 2):
        steps = 0
        reasoning_steps_since_last_tool = 0
        while steps < max_steps:
            # BUG FIX: increment at the top of the loop so the `continue` in
            # the reverse_sentence branch below cannot bypass the step cap and
            # loop forever. Printed step numbers are unchanged.
            steps += 1
            print(f"\U0001f501 Step {steps}")
            state = graph.invoke(state)
            for msg in state["messages"]:
                if isinstance(msg, AIMessage):
                    print("\nπ€ Assistant says:")
                    print("-" * 40)
                    print(msg.content.strip())
                    print("-" * 40)
            latest_message = state["messages"][-1] if state["messages"] else None
            if isinstance(latest_message, AIMessage):
                if latest_message.tool_calls:
                    print("π Tool call detected β continuing loop.")
                    reasoning_steps_since_last_tool = 0  # reset counter
                else:
                    reasoning_steps_since_last_tool += 1
                    print(f"π§ No tool call β reasoning step #{reasoning_steps_since_last_tool}")
                    # π οΈ Handle reverse_sentence manually: re-feed the last
                    # tool output as a fresh human message and loop again.
                    if "reverse_sentence" in latest_message.content.lower():
                        tool_outputs = [msg for msg in state["messages"] if msg.type == "tool"]
                        if tool_outputs:
                            reversed_text = tool_outputs[-1].content.strip()
                            print(f"π Re-feeding reversed message:\n{reversed_text}")
                            state["messages"].append(HumanMessage(content=reversed_text))
                            continue  # loop again with new input
                    if reasoning_steps_since_last_tool >= max_reasoning_steps_after_tool:
                        print("β Final answer assumed after sufficient reasoning.")
                        break
        return state

    return limited_invoke
# 5. BasicAgent class
# (An earlier commented-out BasicAgent variant that returned the raw final
# message was removed here; the current version below post-processes the
# model output with extract_gaia_answer.)
class BasicAgent:
    """Callable agent: takes a question string, returns a normalized answer."""

    def __init__(self, max_steps: int = 5):
        # Keep the cap so it can be forwarded on every call; without this,
        # build_graph's limited_invoke falls back to its own default and the
        # constructor argument is silently ignored.
        self.max_steps = max_steps
        self.graph = build_graph(max_steps)

    def __call__(self, question: str) -> str:
        """Run the agent loop on `question` and extract the final short answer."""
        response = self.graph(
            {"messages": [HumanMessage(content=question)]},
            max_steps=self.max_steps,  # BUG FIX: explicitly forward the cap
        )
        if response.get("messages"):
            final_message = response["messages"][-1]
            raw_content = final_message.content if hasattr(final_message, "content") else "No final message."
            return extract_gaia_answer(raw_content)
        else:
            return "No response."
if __name__ == "__main__":
    # Smoke test: run the agent once on a trivial question.
    question = "What is the capital of France?"
    agent = BasicAgent()
    print(agent(question))