Spaces:
Runtime error
Runtime error
File size: 11,598 Bytes
5dfdf10 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 |
# Libraries
from langchain_core.tools.base import BaseTool
from langgraph.graph import START, END, StateGraph
from typing import TypedDict, List, Optional, Literal, Union
from langchain_openai import ChatOpenAI
from langchain_core.messages import (
HumanMessage,
AIMessage,
SystemMessage,
BaseMessage,
ToolMessage,
)
from langchain_core.runnables.graph import MermaidDrawMethod
from langgraph.checkpoint.memory import MemorySaver
from asyncio import to_thread # Asyncronous processing
from dotenv import load_dotenv
import os, sys
import aiofiles
import sys
# from langgraph.prebuilt import ToolNode, tools_condition
from langfuse.callback import CallbackHandler
from playwright.async_api import async_playwright
from langchain_community.agent_toolkits.playwright.toolkit import (
PlayWrightBrowserToolkit,
)
import asyncio
# Module-level tool state, populated lazily by initialize_tools().
tools_list = []  # combined tool list (local project tools + Playwright browser tools)
clean_browser = None  # async callable that shuts down the Playwright browser; None until init
tools_by_name = {}  # tool-name -> tool lookup used by tools_node
model_with_tools = None  # ChatOpenAI model with the tool list bound; None until init
_tools_initialized = False  # one-shot guard flag, only written under _tools_lock
_tools_lock = asyncio.Lock()  # serializes concurrent initialization attempts
async def initialize_tools():
    """Lazily initialize the tool stack exactly once (async-safe).

    Populates the module-level ``tools_list``, ``clean_browser``,
    ``tools_by_name`` and ``model_with_tools`` globals. The whole body
    runs under ``_tools_lock`` so concurrent callers cannot
    double-initialize; later callers see ``_tools_initialized`` and
    return immediately.

    Raises
    ------
    Exception
        Re-raises any failure from ``import_local_modules`` or
        ``setup_tools`` after logging it.
    """
    global tools_list, clean_browser, tools_by_name, model_with_tools, _tools_initialized
    async with _tools_lock:  # the lock is released automatically on exit
        if _tools_initialized:
            print("Tools already initialized")
            return
        print("Initializing tools")
        try:
            await import_local_modules()
            tools_list, clean_browser = await setup_tools()
            tools_by_name = {tool.name: tool for tool in tools_list}
            model_with_tools = model.bind_tools(tools_list)
            _tools_initialized = True
            print("Initialized tools")
        except Exception as e:
            print(f"Error when initializing tools: {e}")
            raise
async def import_local_modules() -> None:
    """Make the local ``src`` and ``src/tools`` directories importable.

    Appends both absolute paths to ``sys.path``. Idempotent: paths that
    are already present are not appended again, so repeated calls (e.g.
    from every ``initialize_tools`` invocation) do not grow the path
    list with duplicates.
    """
    # Resolve the paths off the event loop, matching the async style of
    # the surrounding setup code.
    src_path = await asyncio.to_thread(os.path.abspath, "src")
    tools_path = await asyncio.to_thread(os.path.abspath, "src/tools")
    for path in (src_path, tools_path):
        if path not in sys.path:  # avoid duplicate entries on re-calls
            sys.path.append(path)
# asyncio.run(import_local_modules())  # DEPRECATED
# Make the local src/ and src/tools packages importable at module load time
# so the `from tools import ...` below resolves.
sys.path.append(os.path.abspath("src"))
sys.path.append(os.path.abspath("src/tools"))
# Project-local tool modules (resolved via the sys.path entries above).
from tools import (
    calculator,
    search,
    code_executor,
    transcriber,
    post_processing,
    handle_text,
    pandas_toolbox,
    handle_json,
    chess_tool,
    handle_images,
)
# Load credentials
# var = "OPENAI_API_KEY"
# os.env[var] = os.getenv(var)
MAX_ITERATIONS = 7  # tool-call rounds allowed before should_use_tool forces END
ROOT_DIR = "/home/santiagoal/current-projects/chappie/"  # TODO: machine-specific, make configurable
AGENT_PROMPTS_DIR = os.path.join(ROOT_DIR, "prompts/agent/")
#SYS_MSG_PATH = os.path.join(AGENT_PROMPTS_DIR, "gaia_system_message.md")
load_dotenv()
# NOTE(review): defaulting LANGGRAPH_STUDIO to "true" means the MemorySaver
# checkpointer below is skipped unless the env var is explicitly "false" —
# the original author flagged this default as a BUG; confirm intended default.
use_studio = os.getenv("LANGGRAPH_STUDIO", "true").lower() == "true" # BUG
# LLM Model
async def set_sys_msg(prompt_path: str) -> str:
    """Asynchronously read a prompt file and return its full text.

    Parameters
    ----------
    prompt_path : str
        Path to the (markdown) prompt file.

    Returns
    -------
    str
        The complete file contents.
    """
    # Read the whole file in one call instead of accumulating line-by-line
    # with `+=`, which is needlessly slow (worst-case quadratic) and noisier.
    async with aiofiles.open(prompt_path, "r") as f:
        return await f.read()
#SYSTEM_MESSAGE = asyncio.run(set_sys_msg(prompt_path=SYS_MSG_PATH))
# Base LLM used by the agent node; tools are bound to it lazily in initialize_tools().
model = ChatOpenAI(model="gpt-4o", temperature=0.5)
# Langfuse tracing callback, attached to graph invocations in test_app/run_agent.
langfuse_callback_handler = CallbackHandler()
# Define tools to use
async def setup_tools():
    """Assemble the complete tool set for the agent.

    Combines the local project tools with the Playwright browser toolkit
    running against a freshly launched headless Chromium instance.

    Returns
    -------
    tuple
        ``(tools, cleanup)`` where ``tools`` is the combined tool list and
        ``cleanup`` is an async callable that closes the browser and stops
        the Playwright runtime.
    """
    # Local project tools (modules imported at module load from src/tools).
    local_tools = [
        calculator.sum_,
        calculator.subtract,
        calculator.multiply,
        calculator.divide,
        search.web_search,
        search.pull_youtube_video,
        #search.fetch_online_pdf,
        code_executor.code_executor,
        transcriber.transcriber,
        post_processing.sort_items_and_format,
        handle_text.handle_text,
        pandas_toolbox.read_df,
        pandas_toolbox.query_df,
        handle_json.handle_json,
        chess_tool.grab_board_view,
        chess_tool.extract_fen_position,
        chess_tool.predict_next_best_move,
        handle_images.detect_objects,
    ]
    # Launch a headless Chromium instance for the browser toolkit.
    pw_runtime = await async_playwright().start()
    chromium = await pw_runtime.chromium.launch(headless=True)

    async def shutdown_browser():
        # Close the browser before stopping the Playwright runtime.
        await chromium.close()
        await pw_runtime.stop()

    # Browser tools backed by the shared Chromium instance.
    browser_tools = PlayWrightBrowserToolkit.from_browser(
        async_browser=chromium
    ).get_tools()
    # Raise the default per-tool timeout to 60 s where the tool supports it.
    for browser_tool in browser_tools:
        if hasattr(browser_tool, "timeout"):
            browser_tool.timeout = 60000
    return local_tools + browser_tools, shutdown_browser
# tools_list, clean_browser = asyncio.run(setup_tools()) # DEPRECATED
# ToolNode(tools=tools_list, name="tools", )
#model_with_tools = model.bind_tools(tools_list) # DEPRECATED
# State
class TaskState(TypedDict):
    """Graph state shared by every node in the agent graph."""
    # Conversation history (system/human/AI/tool messages), oldest first.
    messages: List[BaseMessage]
    # Tool-call rounds performed so far; may be absent/None before prepare_agent runs.
    iteration: Optional[int]
# tools_by_name = {tool.name: tool for tool in tools}
# tools_by_name = {tool.name: tool for tool in tools_list} # Q: Does it work? # DEPRECATED
# Nodes
async def prepare_agent(state: TaskState) -> dict[str, list]:
    """Entry node: ensure tools are ready and a system message is present.

    Parameters
    ----------
    state : TaskState
        Incoming graph state; ``messages`` may be empty.

    Returns
    -------
    dict
        State update with the (possibly system-prefixed) message list and
        the iteration counter reset to 0.

    Raises
    ------
    Exception
        Propagates any failure from tool initialization after logging it.
    """
    try:
        await initialize_tools()
    except Exception as e:
        print(f"Error initializing tools: {e}")
        raise
    # Copy the list so we never mutate the caller's state in place.
    messages = list(state.get("messages", []))
    if not any(isinstance(m, SystemMessage) for m in messages):
        sys_msg_path = os.path.join(AGENT_PROMPTS_DIR, "gaia_system_message.md")
        sys_msg = await set_sys_msg(prompt_path=sys_msg_path)
        messages.insert(0, SystemMessage(content=sys_msg))
    return {"messages": messages, "iteration": 0}
async def tools_node(state: TaskState) -> dict[str, list]:
    """Execute every tool call requested by the last AI message.

    Looks up each requested tool in the module-level ``tools_by_name``
    registry, awaits it, and appends one ``ToolMessage`` per call to a
    copy of the chat history (history is kept intact, not replaced).

    Returns
    -------
    dict
        State update containing the extended message list.
    """
    # Copy the history instead of mutating the incoming state list in place.
    result = list(state.get("messages", []))
    for tool_call in state["messages"][-1].tool_calls:
        tool = tools_by_name[tool_call["name"]]
        observation = await tool.ainvoke(tool_call["args"])
        result.append(ToolMessage(content=observation, tool_call_id=tool_call["id"]))
    return {"messages": result}
async def agent(state: TaskState) -> dict:
    """
    Agent node: run the tool-bound LLM on the current chat history.

    Parameters
    ----------
    state : TaskState
        Information history, which flows through the agent graph

    Returns
    -------
    dict
        State update with the model's reply appended and the iteration
        counter incremented only when the reply contains tool calls.

    Raises
    ------
    ValueError
        If the model returns something other than an ``AIMessage``.

    Example:
        >>> from langchain_core.messages import HumanMessage
        >>> state = {"messages": [HumanMessage(content="What is LangGraph?")]}
        >>> output = await agent(state)  # doctest: +SKIP
        >>> isinstance(output["messages"][-1].content, str)  # doctest: +SKIP
        True
    """
    # Copy so the caller's message list is never mutated in place.
    chat_history = list(state.get("messages", []))
    # `or 0` guards against a state where "iteration" is present but None.
    iterations = state.get("iteration", 0) or 0
    model_response = await model_with_tools.ainvoke(input=chat_history)
    # Ensure the response is valid before appending.
    if isinstance(model_response, AIMessage):
        if model_response.tool_calls:
            # Only count an iteration when the model actually requested tools.
            iterations += 1
        chat_history.append(model_response)
    else:
        raise ValueError("Invalid model response format")
    return {"messages": chat_history, "iteration": iterations}
# chat_history = state.get("messages", [])
# formatted_messages = [
# {
# "type": "human" if isinstance(msg, HumanMessage) else "ai",
# "content": msg.content,
# }
# for msg in chat_history
# ]
#
##formatted_messages = [
## SystemMessage(content="You are a helpful assistant.")
##] + formatted_messages
# formatted_messages = messages_from_dict(formatted_messages)
# response = model_with_tools.invoke(formatted_messages)
# current_iterations = state.get("iteration", 0)
# chat_history.append(response)
## Handle tool calls if they exist
# if hasattr(response, "tool_calls") and response.tool_calls:
# # This will trigger the tool execution in LangGraph
# return {
# "messages": chat_history + [response],
# "iteration": current_iterations + 1,
# }
## last_message = chat_history[-1]
## if isinstance(last_message, AIMessage) and hasattr(last_message, "tool_calls") and last_message.tool_calls:
## print(last_message.tool_calls)
# output = {"messages": chat_history}
# return output
# Conditional Edges
def should_use_tool(state: TaskState) -> Literal["tools", END]:
    """
    Decide whether the graph should route to the tools node or finish.

    Routes to ``"tools"`` when the last chat message is an ``AIMessage``
    carrying tool calls and the iteration budget (``MAX_ITERATIONS``) is
    not exhausted; otherwise routes to ``END``.

    Parameters
    ----------
    state : TaskState
        Information history, which flows through the agent graph

    Returns
    -------
    Literal["tools", END]
        Next graph node to where the process should flow
    """
    chat_history = state.get("messages", [])
    if not chat_history:
        # Nothing to inspect — finish rather than crash on an empty history.
        return END
    last_message = chat_history[-1]
    # `or 0` guards against a state where "iteration" is present but None.
    current_iterations = state.get("iteration", 0) or 0
    if current_iterations > MAX_ITERATIONS:
        return END
    if isinstance(last_message, AIMessage) and last_message.tool_calls:
        return "tools"
    return END
# Build Graph
memory = MemorySaver()  # in-memory checkpointer for thread-level persistence
builder = StateGraph(state_schema=TaskState)
builder.add_node("prepare_agent", prepare_agent)
builder.add_node("agent", agent)
builder.add_node("tools", tools_node)
builder.add_edge(START, "prepare_agent")
builder.add_edge("prepare_agent", "agent")
# After the agent speaks, either run the requested tools or finish.
builder.add_conditional_edges(
    source="agent", path=should_use_tool, path_map=["tools", END]
)
builder.add_edge("tools", "agent")  # tool results loop back to the agent
# builder.add_edge("agent", END)
# memory = MemorySaver()
# NOTE(review): when use_studio is true the checkpointer is skipped —
# presumably because LangGraph Studio manages persistence itself; confirm.
graph = builder.compile() if use_studio else builder.compile(checkpointer=memory)
# Save graph
# graph_json = graph.to_json()
# with open("../../langgraph.json", "w") as f:
# f.write(graph_json)
# Save graph image
async def save_agent_architecture(
    output_path: str = "./images/agent_architecture.png",
) -> None:
    """Render the compiled graph as a Mermaid PNG and write it to disk.

    Parameters
    ----------
    output_path : str, optional
        Destination file. Defaults to the legacy ``./images`` location;
        TODO: the new images path is /home/santiagoal/current-projects/chappie/data/images
    """
    # Mermaid rendering is blocking, so run it in a worker thread.
    graph_image_bytes = await to_thread(lambda: graph.get_graph().draw_mermaid_png())
    with open(output_path, "wb") as f:
        f.write(graph_image_bytes)
# Test app
async def test_app() -> None:
    """
    Test the Agent behavior, including complete conversation thread
    """
    print("Testing App... \n")
    query = input("Ingresa tu pregunta: ")
    # Run one full graph invocation with tracing on a fixed thread.
    initial_state = {"messages": [HumanMessage(content=query)]}
    run_config = {
        "callbacks": [langfuse_callback_handler],
        "configurable": {"thread_id": "1"},
    }
    response = await graph.ainvoke(input=initial_state, config=run_config)
    # Dump the whole conversation, one role-tagged entry per message.
    for msg in response["messages"]:
        print(f"{msg.type.upper()}: {msg.content}\n")
    return None
async def run_agent(
    user_query: str = None,
    print_response: bool = False,
    clean_browser_fn=None,
) -> Union[str, float, int]:
    """Run one agent turn and return the final AI answer.

    Parameters
    ----------
    user_query : str, optional
        Question to answer; prompts on stdin when omitted.
    print_response : bool, optional
        When True, echo the answer to stdout.
    clean_browser_fn : callable, optional
        Async cleanup callable. When omitted (or None), falls back to the
        module-level ``clean_browser`` closure created during tool
        initialization.

    Returns
    -------
    Union[str, float, int]
        Content of the last AI message.
    """
    try:
        query = user_query if user_query else input("Pass your question: ")
        response = await graph.ainvoke(
            input={"messages": [HumanMessage(content=query)]},
            config={
                "callbacks": [langfuse_callback_handler],
                "configurable": {"thread_id": "1"},
            },
        )
        ai_answer = response.get("messages", [])[-1].content
        if print_response:
            print(ai_answer)
        return ai_answer
    finally:
        # BUGFIX: __main__ used to pass the module global captured before
        # initialization (always None), so the browser was never closed.
        # Resolve the cleanup callable HERE, after the graph run has
        # populated the module-level `clean_browser`.
        cleanup = clean_browser_fn or clean_browser
        if cleanup:
            await cleanup()
if __name__ == "__main__":
    if "dev" not in sys.argv:
        # Do not pass the module-level `clean_browser` here: it is still None
        # at this point (it is only set during graph execution), so let
        # run_agent resolve its own cleanup fallback instead.
        asyncio.run(run_agent(print_response=True))
# TODO: Use a Local class for general path management
# TODO: Modularize script |