Spaces:
Sleeping
Sleeping
| # agent.py | |
| import os | |
| import logging | |
| from typing import TypedDict, Annotated, Any | |
| from langgraph.graph import StateGraph, END, START | |
| from langgraph.graph.message import add_messages | |
| from dotenv import load_dotenv | |
| from langgraph.prebuilt import ToolNode | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, ToolMessage, SystemMessage | |
| from tools import TOOLS # Your tools list should be defined here | |
| import requests | |
| import re | |
| import json | |
# --- Logging Setup ---
# Load environment variables (e.g. OPENAI_API_KEY) before any client is built.
load_dotenv()
# Log file lives next to this module so it is easy to find regardless of CWD.
LOG_FILE = os.path.join(os.path.dirname(__file__), "agent.log")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        # Mirror every record to the console...
        logging.StreamHandler(),
        # ...and to agent.log; mode="w" truncates the log on each run.
        logging.FileHandler(LOG_FILE, mode="w", encoding="utf-8"),
    ],
)
# Named logger used by every function in this module.
logger = logging.getLogger("agent_logger")
| # --- Token Counting Helper --- | |
# --- Token Counting Helper ---
def count_tokens(messages, model="gpt-3.5-turbo"):
    """Return the total token count over all message contents, or -1 on failure.

    Parameters:
        messages: iterable of message objects; anything with a truthy
            ``content`` attribute is counted, everything else is skipped.
        model: model name handed to ``tiktoken.encoding_for_model`` to select
            the tokenizer (default "gpt-3.5-turbo", the original hard-coded
            value, kept for backward compatibility).

    Returns -1 when tiktoken is unavailable or any encoding error occurs, so
    callers can treat the count as "unknown" without handling exceptions.
    """
    try:
        # Imported lazily so the agent still runs when tiktoken is absent.
        import tiktoken
        enc = tiktoken.encoding_for_model(model)
        # Sum tokens across every message that actually carries content;
        # str() guards against non-string (e.g. list-shaped) content.
        return sum(
            len(enc.encode(str(msg.content)))
            for msg in messages
            if hasattr(msg, "content") and msg.content
        )
    except ImportError:
        logger.warning("tiktoken not installed, skipping token count.")
        return -1
    except Exception as e:
        # Best-effort metric: never let token accounting crash the agent.
        logger.warning(f"Token counting error: {e}")
        return -1
# LLM definition using GPT‑o3
# System prompt: forces bare-answer output and tool-first behavior for files/images.
system_prompt = (
    "You are a helpful assistant. The current year is 2025. When answering, output ONLY the answer to the question, with no extra text, explanation, or formatting. "
    "If you call a tool and receive its output, use the tool output as the main source for your answer. "
    "You may analyze, summarize, or combine tool outputs if needed to answer the question, but do not ignore tool outputs or say you cannot access files or images. "
    "Do not include phrases like 'Final answer', 'The answer is', or any commentary. Output only the answer string. "
    "If a question involves a file, audio, or image, use the appropriate tool to access or process the file. Do not say you cannot access files—always attempt a tool call first. "
    "If a tool result contains the answer, output the answer immediately. Do not make additional tool calls if the answer is already present in the tool result. "
)
chat = ChatOpenAI(
    model="o3",  # GPT‑o3 model
    temperature=1,  # NOTE(review): o3 reportedly only accepts the default temperature — confirm
    openai_api_key=os.getenv("OPENAI_API_KEY"),
)
# Bind tools with the LLM so it can emit structured tool calls for TOOLS.
chat_with_tools = chat.bind_tools(TOOLS)
| # Agent state: tracks conversation history | |
# Agent state: tracks conversation history plus the tool-call budget counter.
class AgentState(TypedDict):
    """Graph state shared by every node.

    ``messages`` uses the ``add_messages`` reducer so node returns are merged
    into the history rather than replacing it. ``tool_call_attempts`` is
    declared here because the assistant node reads and returns it; without a
    schema entry LangGraph would not persist that key between steps.
    """
    messages: Annotated[list[AnyMessage], add_messages]
    # Number of tool-call rounds the assistant has spent so far (capped at 2).
    tool_call_attempts: int
| # Assistant node: single chat invocation, LLM always decides | |
# Assistant node: single chat invocation, LLM always decides
def assistant(state: AgentState) -> dict[str, Any]:
    """Run one LLM step: log the conversation, enforce the tool-call budget,
    invoke the tool-bound model once, and return the updated state.

    Parameters:
        state: graph state; reads ``messages`` and (optionally)
            ``tool_call_attempts``.

    Returns a dict with the full message history plus the new AI message,
    and the (possibly incremented) ``tool_call_attempts`` counter.
    """
    logger.info("[Agent] Thinking...")
    logger.info("[Agent] Messages so far:")
    for m in state['messages']:
        # Skip the system prompt to keep the log readable.
        if hasattr(m, 'content') and isinstance(m.content, str) and m.content.startswith("You are a helpful assistant."):
            continue
        logger.info(f"{type(m).__name__}: {getattr(m, 'content', str(m))}")
    logger.info("-" * 40)

    # Work on a copy so the caller's state is never mutated in place.
    messages = list(state["messages"])
    attempts = state.get('tool_call_attempts', 0)
    over_budget = attempts >= 2

    if over_budget:
        # Budget exhausted: inject a sentinel ToolMessage telling the model to
        # answer from context instead of calling more tools.
        logger.info("[Agent] Tool call limit reached. Injecting tool limit message.")
        messages.append(ToolMessage(
            content="YOU CAN NO LONGER MAKE ANY TOOL CALL, PLEASE ANSWER WITH THE CONTEXT YOU HAVE, OR CLEARLY STATE THAT YOU DO NOT HAVE ENOUGH DATA.",
            tool_call_id="tool_limit",
        ))

    # Single invocation path (the original duplicated this for both branches).
    next_msg = chat_with_tools.invoke(messages)
    logger.info(f"[Agent] LLM response: {next_msg.content}")
    if getattr(next_msg, "tool_calls", None):
        # Only count attempts while still under budget, matching the original
        # behavior where the over-budget branch never incremented.
        if not over_budget:
            attempts += 1
        for tc in next_msg.tool_calls:
            logger.info(f"[Tool Call] {tc['name']} | Args: {tc['args']}")

    return {"messages": messages + [next_msg], "tool_call_attempts": attempts}
| # Condition: check if the assistant wants to use a tool again | |
# Condition: check if the assistant wants to use a tool again
def needs_tool(state: AgentState) -> str:
    """Route after the assistant node.

    Returns "tools" when the most recent message carries tool calls
    (route to the tool node), otherwise "end" (finish the graph).
    """
    last_message = state["messages"][-1]
    return "tools" if getattr(last_message, "tool_calls", None) else "end"
| # Build the graph | |
# Build the graph
def build_langgraph():
    """Compile the agent graph: an assistant ⇄ tools loop.

    The assistant node either routes to the ToolNode (when it emitted tool
    calls) or terminates; every tool run feeds back into the assistant.
    """
    workflow = StateGraph(AgentState)
    workflow.add_node("assistant", assistant)
    workflow.add_node("tools", ToolNode(TOOLS))
    # Entry edge from START is equivalent to set_entry_point("assistant").
    workflow.add_edge(START, "assistant")
    workflow.add_conditional_edges("assistant", needs_tool, {"tools": "tools", "end": END})
    workflow.add_edge("tools", "assistant")
    return workflow.compile()
| # High-level solve function with logging and token counting | |
# High-level solve function with logging and token counting
def solve(question: str) -> str:
    """Answer *question* by running the agent graph to completion.

    Drives the graph step-by-step, logging tool calls/results, enforcing a
    per-question Google-search cap, and bailing out with a fallback answer
    when the same tool is retried too often or the recursion limit is hit.

    Returns the final AI message content, or a fallback string on give-up /
    recursion-limit conditions.
    """
    logger.info(f"[User] {question}")
    graph = build_langgraph()
    state = {"messages": [SystemMessage(content=system_prompt), HumanMessage(content=question)]}
    step = 0
    all_messages = list(state["messages"])  # full transcript for token stats
    google_search_calls = 0
    MAX_GOOGLE_SEARCH_CALLS = 10
    tool_call_counts = {}  # (tool name, normalized args) -> repeat count
    GIVE_UP_THRESHOLD = 5
    fallback_answer = "Unable to determine from available data."
    recursion_fallback = "Unable to find the answer with the given data."
    try:
        while True:
            step += 1
            logger.info(f"\n========== Step {step} ==========")
            # Run one pass of the graph with recursion_limit set to 13.
            result = graph.invoke(state, {"recursion_limit": 13})
            # Only the messages appended during this invoke — diffed by length.
            new_msgs = result["messages"][len(state["messages"]):]
            for msg in new_msgs:
                if hasattr(msg, "tool_calls") and msg.tool_calls:
                    for tool_call in msg.tool_calls:
                        # Args may live under 'args' or OpenAI-style 'function.arguments'.
                        logger.info(f"[Tool Call] {tool_call['name']} | Args: {tool_call.get('args', tool_call.get('function', {}).get('arguments', ''))}")
                if isinstance(msg, ToolMessage):
                    logger.info(f"[Tool Result] {msg.content}")
                if isinstance(msg, AIMessage):
                    logger.info(f"[Agent Thinking] {msg.content}")
                # Intercept tool calls and block google_search_tool after limit.
                # Having both tool_call_id and name identifies a tool result message.
                if hasattr(msg, "tool_call_id") and hasattr(msg, "name") and msg.name == "google_search_tool":
                    google_search_calls += 1
                    if google_search_calls > MAX_GOOGLE_SEARCH_CALLS:
                        # Replace the over-limit result in the history with a refusal
                        # so the model sees the cap instead of fresh search output.
                        refusal = ToolMessage(
                            content="Google search tool call refused: limit of 10 calls per question reached.",
                            tool_call_id=msg.tool_call_id
                        )
                        result["messages"][result["messages"].index(msg)] = refusal
                        logger.info("[ToolMessage] Google search tool call refused: limit reached.")
                # Give-up heuristic: count repeats of (tool, args); abandon the
                # question once the same call exceeds GIVE_UP_THRESHOLD.
                if hasattr(msg, "name") and hasattr(msg, "tool_call_id"):
                    tool_args = ""
                    # NOTE(review): ToolMessages rarely carry additional_kwargs['tool_calls'],
                    # so tool_args is usually "" and the key degrades to name-only — confirm.
                    if hasattr(msg, "additional_kwargs") and msg.additional_kwargs and "tool_calls" in msg.additional_kwargs:
                        tool_calls = msg.additional_kwargs["tool_calls"]
                        if tool_calls and isinstance(tool_calls, list):
                            tool_args = tool_calls[0].get("function", {}).get("arguments", "")
                    tool_key = (msg.name, tool_args.strip().lower())
                    tool_call_counts[tool_key] = tool_call_counts.get(tool_key, 0) + 1
                    if tool_call_counts[tool_key] > GIVE_UP_THRESHOLD:
                        logger.info(f"[Agent] Give up condition met for tool {msg.name} with similar arguments: {tool_args}")
                        return fallback_answer
            all_messages.extend(new_msgs)
            state["messages"] = result["messages"]
            # Next action logging
            last_msg = state["messages"][-1]
            if getattr(last_msg, "tool_calls", None):
                logger.info("[Next Action] Agent will call a tool.")
            else:
                logger.info("[Next Action] Agent will answer.")
            # No pending tool calls means the model produced its final answer.
            if not getattr(last_msg, "tool_calls", None):
                break
        logger.info(f"[Agent] Final answer: {state['messages'][-1].content}")
        token_count = count_tokens(all_messages)
        if token_count >= 0:
            logger.info(f"[Stats] Total tokens used: {token_count}")
        return state["messages"][-1].content
    except Exception as e:
        # Imported here to avoid a hard top-level dependency on the errors module.
        import langgraph.errors
        if isinstance(e, langgraph.errors.GraphRecursionError):
            # Hitting the recursion limit is an expected "give up" outcome.
            logger.info("[Agent] Recursion limit reached, returning fallback answer.")
            return recursion_fallback
        else:
            logger.error(f"[Agent] Unexpected error: {e}")
            raise
def download_file(url, dest_path, timeout=60):
    """Stream a file from *url* to *dest_path* on disk.

    Parameters:
        url: HTTP(S) URL to fetch.
        dest_path: local filesystem path the payload is written to.
        timeout: seconds before the request is aborted (new keyword, default
            60 — the previous version had no timeout and could hang forever
            on a stalled server).

    Raises:
        requests.HTTPError: on a non-2xx response.
    """
    # Context manager closes the connection even if writing to disk fails.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(dest_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:  # skip empty keep-alive chunks
                    f.write(chunk)
    print(f"Downloaded {url} to {dest_path}")