| import os |
| from textwrap import dedent |
| from typing import TypedDict, Annotated, Optional, Any, Callable, Sequence, Union |
|
|
| from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage |
| from langchain_core.tools import BaseTool |
| from langchain_openai import ChatOpenAI |
| from langchain_tavily import TavilySearch |
| from langgraph.constants import START |
| from langgraph.errors import GraphRecursionError |
| from langgraph.graph import add_messages, StateGraph |
| from langgraph.prebuilt import ToolNode, tools_condition |
| from langgraph.pregel import PregelProtocol |
| from loguru import logger |
| from pydantic import SecretStr |
|
|
| from tools.excel_to_text import excel_to_text |
| from tools.execute_python_code_from_file import execute_python_code_from_file |
| from tools.add_integers import add_integers |
| from tools.produce_classifier import produce_classifier |
| from tools.sort_words_alphabetically import sort_words_alphabetically |
| from tools.transcribe_audio import transcribe_audio |
| from tools.web_page_information_extractor import web_page_information_extractor |
| from tools.wikipedia_search import wikipedia_search |
| from tools.youtube_transcript import youtube_transcript |
|
|
|
|
class AgentState(TypedDict):
    """State schema handed to the agent's ``StateGraph``.

    The graph built for ``ShrewdAgent`` carries a single channel: the
    conversation so far.
    """

    # Conversation history. The ``add_messages`` annotation registers
    # LangGraph's message reducer for this key, so node outputs are merged
    # into the existing list by that reducer.
    messages: Annotated[list[AnyMessage], add_messages]
|
|
|
|
class ShrewdAgent:
    """Tool-calling question-answering agent backed by a LangGraph state graph.

    Instantiating the class loads the tools, binds them to a ``gpt-4o`` chat
    model and compiles the assistant/tools graph. Calling the instance with a
    question streams the graph and returns the content of the last assistant
    message, or ``fallback_answer`` when no usable answer was produced.
    """

    # System prompt sent ahead of every question.
    message_system = dedent("""
        You are a general AI assistant equipped with a suite of external tools. Your task is to
        answer the following question as accurately and helpfully as possible by using the tools
        provided. Do not write or execute code yourself. For any operation requiring computation,
        data retrieval, or external access, explicitly invoke the appropriate tool.

        Follow these guidelines:
        - Clearly explain your reasoning step by step.
        - Justify your choice of tool(s) at each step.
        - If multiple interpretations are possible, outline them and explain your reasoning for selecting one.
        - If the answer requires external data or inference, retrieve or deduce it via the available tools.
        Important: Your final output MUST be only a number or a word with no additional text or explanation,
        unless the response format is explicitly specified in the question. Do not include reasoning,
        commentary, or any other content beyond the requested answer.""")

    # Value passed as LangGraph's ``recursion_limit`` for each question.
    # Override on a subclass or an instance to allow longer tool chains.
    recursion_limit = 18

    # Returned when the run did not end on an assistant message with content.
    fallback_answer = "I couldn't find the answer"

    def __init__(self):
        """Load the tools, bind them to the LLM and compile the agent graph.

        Raises:
            KeyError: If the ``OPENAI_API_KEY`` environment variable is unset.
        """
        # TavilySearch is the only tool that is constructed here, so it is
        # announced before instantiation; a failure is then easy to attribute.
        print("Loading: TavilySearch")
        tools = [TavilySearch()]

        # Ready-made tools imported from the project's ``tools`` package,
        # listed in the order they are offered to the model.
        ready_made_tools = {
            "wikipedia_search": wikipedia_search,
            "web_page_information_extractor": web_page_information_extractor,
            "youtube_transcript": youtube_transcript,
            "produce_classifier": produce_classifier,
            "sort_words_alphabetically": sort_words_alphabetically,
            "excel_to_text": excel_to_text,
            "execute_python_code_from_file": execute_python_code_from_file,
            "add_integers": add_integers,
            "transcribe_audio": transcribe_audio,
        }
        for label, tool in ready_made_tools.items():
            print(f"Loading: {label}")
            tools.append(tool)

        self.tools = tools

        print("Binding tools to LLM...")
        self.llm = ChatOpenAI(
            model="gpt-4o",
            temperature=0,
            api_key=SecretStr(os.environ['OPENAI_API_KEY']),
        ).bind_tools(self.tools)

        print("Compiling agent graph...")

        def assistant_node(state: AgentState):
            # One LLM turn over the whole history; the reply is returned as a
            # state update for the ``messages`` channel.
            return {"messages": [self.llm.invoke(state["messages"])]}

        self.agent = _build_state_graph(AgentState, assistant_node, self.tools)
        logger.info(f"Agent initialized with tools: {[tool.name for tool in self.tools]}")
        logger.debug(f"system message:\n{self.message_system}")

    def __call__(self, question: str) -> str:
        """Run the agent on ``question`` and return its final answer.

        Args:
            question: The question text, sent as the human message.

        Returns:
            The content of the last assistant message when the stream ended on
            one with non-empty content, otherwise ``fallback_answer``.
        """
        logger.info(f"Agent received question:\n{question}")
        initial_state = {
            "messages": [
                SystemMessage(self.message_system),
                HumanMessage(question),
            ]
        }

        # Only the most recent chunk decides the answer, so earlier chunks are
        # logged and dropped rather than accumulated.
        last_chunk = None
        try:
            for chunk in self.agent.stream(
                initial_state,
                {"recursion_limit": self.recursion_limit},
            ):
                for node_name in ("assistant", "tools"):
                    update = chunk.get(node_name)
                    if update:
                        logger.debug(f"\n{update.get('messages')[0].pretty_repr()}")
                last_chunk = chunk
        except GraphRecursionError as e:
            # Deliberate best effort: keep whatever was streamed before the
            # limit was hit and let the extraction below decide the answer.
            logger.error(f"GraphRecursionError: {e}")

        final_answer = self._extract_final_answer(last_chunk)
        logger.info(f"Agent returning answer: {final_answer}")
        return final_answer

    @classmethod
    def _extract_final_answer(cls, last_chunk):
        """Return the answer carried by the last streamed chunk.

        Falls back to ``fallback_answer`` when there is no chunk, the chunk
        did not come from the assistant node, its message list is missing or
        empty, or the last message has empty content (for example a message
        that only requests tool calls).
        """
        update = last_chunk.get("assistant") if last_chunk else None
        messages = update.get("messages") if update else None
        if messages:
            content = messages[-1].content
            if content:
                return content
        return cls.fallback_answer
|
|
|
|
def _build_state_graph(
        state_schema: Optional[type[Any]],
        assistant: Callable,
        tools: Sequence[Union[BaseTool, Callable]]) -> PregelProtocol:
    """Assemble and compile the assistant/tools loop.

    Args:
        state_schema: State type handed to ``StateGraph``.
        assistant: Callable registered as the ``"assistant"`` node.
        tools: Tools wrapped in a ``ToolNode`` registered as ``"tools"``.

    Returns:
        The compiled graph.
    """
    builder = StateGraph(state_schema)

    # Nodes: the LLM turn and the tool executor.
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))

    # Edges: start at the assistant, let ``tools_condition`` choose the next
    # hop after each assistant turn, and always come back from the tools.
    builder.add_edge(START, "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")

    return builder.compile()