|
|
import os |
|
|
from typing import Any, List, TypedDict |
|
|
from langgraph.graph import StateGraph |
|
|
from langgraph.prebuilt import ToolNode, tools_condition |
|
|
from langchain_core.messages import HumanMessage, SystemMessage |
|
|
from langchain_core.rate_limiters import InMemoryRateLimiter |
|
|
from langchain_openai import ChatOpenAI |
|
|
from tools import ( |
|
|
describe_image_tool, |
|
|
parse_excel_tool, |
|
|
webpage_extraction_tool, |
|
|
brave_web_search, |
|
|
python_code_interpreter_tool, |
|
|
audio_file_transcriber, |
|
|
get_youtube_transcript |
|
|
) |
|
|
|
|
|
class AgentState(TypedDict):
    """Shared state threaded through the LangGraph workflow."""

    # Running conversation: System/Human/AI/Tool messages, newest last.
    messages: List[Any]

    # The original user question, kept verbatim.
    question: str

    # Path to a file attached to the question.
    # NOTE(review): run() below does not populate this key — confirm callers.
    file_path: str

    # Final model answer, filled in after the graph finishes.
    final_answer: str
|
|
|
|
|
# Toolbelt exposed to the LLM: bound via bind_tools() and executed by
# ToolNode in LangGraphAgent below.
tools = [
    describe_image_tool,
    parse_excel_tool,
    webpage_extraction_tool,
    brave_web_search,
    python_code_interpreter_tool,
    audio_file_transcriber,
    get_youtube_transcript
]
|
|
|
|
|
# Throttle for the chat model: 0.1 req/s = at most one request every
# 10 seconds, polled every 0.1 s, with a burst bucket of 10 slots.
rate_limiter = InMemoryRateLimiter(
    requests_per_second=0.1,
    check_every_n_seconds=0.1,
    max_bucket_size=10,
)
|
|
|
|
|
class LangGraphAgent:
    """Question-answering agent built on LangGraph.

    The LLM is bound to the module-level ``tools`` list. ``agent_node``
    drives an LLM -> tools -> LLM loop until the model produces a reply
    with no tool calls; ``run`` extracts that reply as the final answer.
    """

    def __init__(self, model_name: str = "gpt-4o"):
        """Build the rate-limited LLM, tool bindings, and compiled graph.

        Args:
            model_name: OpenAI chat model identifier.
        """
        self.llm = ChatOpenAI(
            model=model_name,
            max_tokens=2000,
            temperature=0,
            rate_limiter=rate_limiter,  # module-level limiter: ~1 request / 10 s
        )
        self.llm_with_tools = self.llm.bind_tools(tools)
        self.tool_node = ToolNode(tools)
        self.graph = self.create_graph().compile()

    def create_graph(self) -> StateGraph:
        """Creates a state graph for the agent's workflow.

        Returns:
            The uncompiled StateGraph (caller compiles it).
        """
        graph = StateGraph(AgentState)
        graph.add_node("agent", self.agent_node)
        graph.add_node("tools", self.tool_node)

        graph.set_entry_point("agent")

        # tools_condition routes to "tools" when the last message carries
        # tool calls. agent_node resolves its own tool calls before
        # returning, so in practice this routes straight to END; the edge
        # pair is kept for safety.
        graph.add_conditional_edges("agent", tools_condition)
        graph.add_edge("tools", "agent")

        return graph

    def agent_node(self, state: AgentState):
        """Creates a node for the agent that uses the model to respond to user queries.

        Prepends the system prompt on the first turn, truncates long
        histories, then loops LLM/tool execution until the model stops
        requesting tools.
        """
        messages = list(state['messages'])  # copy: avoid mutating caller state

        if len(messages) == 1:
            # First turn: inject the system prompt ahead of the user question.
            # (Fixed: the original concatenated sentences without separating
            # spaces and misspelled "webpage".)
            system_prompt = ("You are a helpful assistant that can answer questions using various tools. "
                             "You must answer the given question using as few words as possible, or the given format, if any. "
                             "If the answer is a number, you must return the number only, do not include symbols or commas. "
                             "If you need to search the web for information and aren't given a URL, always use a search tool before using a webpage extraction tool so you always have a legit website. "
                             "If given a Python file, execute it with the code interpreter tool (riza_exec_python)")
            messages = [SystemMessage(system_prompt)] + messages

        MAX_HISTORY = 3

        if len(messages) > MAX_HISTORY:
            # Keep all system messages plus the most recent turns. Exclude
            # SystemMessages from the tail so one that falls inside the last
            # MAX_HISTORY entries is not duplicated (original bug).
            # NOTE(review): truncation can orphan ToolMessages from their
            # AIMessage tool_calls, which some chat APIs reject — confirm.
            system_msgs = [m for m in messages if isinstance(m, SystemMessage)]
            tail = [m for m in messages[-MAX_HISTORY:] if not isinstance(m, SystemMessage)]
            messages = system_msgs + tail

        while True:
            response = self.llm_with_tools.invoke(messages)
            messages.append(response)

            if not response.tool_calls:
                break

            # ToolNode executes *every* tool call found on the last message,
            # so invoke it exactly once per response. (The original looped
            # over response.tool_calls and invoked ToolNode with the full
            # response each time, running every tool len(tool_calls) times
            # and appending duplicate ToolMessages.)
            tool_output = self.tool_node.invoke({"messages": [response]})
            messages.extend(tool_output["messages"])

        return {"messages": messages}

    def run(self, question: str, file_path: str = "") -> str:
        """Answer *question* and return the model's final text.

        Args:
            question: Natural-language question for the agent.
            file_path: Optional path to a file attached to the question;
                stored in state (AgentState declares this key — the
                original omitted it).

        Returns:
            The content of the last message produced by the graph.
        """
        state = AgentState(
            messages=[HumanMessage(content=question)],
            question=question,
            file_path=file_path,
            final_answer="",  # was None; AgentState annotates this as str
        )
        result = self.graph.invoke(state)

        final_message = result["messages"][-1]
        if hasattr(final_message, 'content'):
            result['final_answer'] = final_message.content
        else:
            result['final_answer'] = str(final_message)
        return result['final_answer']