# Source: yplam — commit 7c21d30 ("add more tools")
import json
import os
from typing import Annotated, Optional, TypedDict, List
from dotenv import load_dotenv
from langgraph.graph.message import add_messages
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage
from langchain.chat_models import init_chat_model
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.prebuilt import ToolNode
import requests
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.tools import tool
from tool.math import add, divide, multiply, subtract, modulus
from tool.youtube import youtube_transcript
# Load API keys / proxy settings from a local .env file into the environment.
load_dotenv()
# Shared chat model: used directly (image analysis, answer extraction) and,
# via bind_tools below, as the tool-calling agent model.
llm = init_chat_model(
    model="gpt-4o",
    model_provider="openai",
    max_retries=2,
    # Allows pointing at any OpenAI-compatible endpoint; falls back to the official API.
    openai_api_base=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
    openai_api_key=os.getenv("OPENAI_API_KEY"),
    openai_proxy=os.getenv("OPENAI_PROXY"),
)
@tool
def analyze_image_by_url(image_url: str, prompt: str) -> str:
    """Answer *prompt* about the image at *image_url* using the vision-capable model.

    Args:
        image_url: The url of the image to analyze
        prompt: The prompt to use to analyze the image

    Returns:
        The model's textual answer (empty string when no URL is supplied).
    """
    if image_url is None:
        return ""
    # Multimodal message: text instruction plus the image reference.
    user_content = [
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": {"url": image_url}},
    ]
    response = llm.invoke([{"role": "user", "content": user_content}])
    print(f"Response: {response.content}")
    return response.content
def read_file_by_path(file_path: str) -> str:
    """Read the local text file at *file_path* and return its contents.

    Args:
        file_path: Path of the local file to read.

    Returns:
        The file's text contents, or "" when no path is given.
    """
    # NOTE(review): unlike the other tools this function has no @tool decorator,
    # yet it appears in the `tools` list — bind_tools accepts plain callables,
    # but confirm the asymmetry is intentional.
    print(f"Reading file: {file_path}")
    if file_path is None:
        return ""
    # Explicit encoding: the default is locale-dependent and breaks on
    # non-ASCII content on some platforms.
    with open(file_path, "r", encoding="utf-8") as f:
        return f.read()
@tool
def read_file_by_url(file_url: str) -> str:
    """Download the file at *file_url* and return its raw content.

    Args:
        file_url: The url of the file to read

    Returns:
        The raw content of the file
    """
    print(f"Reading file: {file_url}")
    if file_url is None:
        return ""
    # Timeout added: requests has no default timeout, so a stalled host would
    # hang the whole agent run indefinitely.
    response = requests.get(file_url, timeout=30)
    # NOTE(review): .content is bytes, not str as annotated; the tool layer
    # stringifies it — confirm whether decoded .text is actually intended.
    return response.content
@tool
def load_webpage_from_url(url: str) -> str:
    """Load the webpage at *url* and return its content.

    Args:
        url: The url of the webpage to load

    Returns:
        The content of the webpage
    """
    print(f"Loading webpage from: {url}")
    loader = WebBaseLoader(url)
    return loader.load()
@tool
def load_wikipedia(query: str) -> str:
    """Search Wikipedia for *query* and return the top page's content.

    Args:
        query: The query to search Wikipedia for

    Returns:
        The content of the Wikipedia page
    """
    print(f"Loading Wikipedia for: {query}")
    loader = WikipediaLoader(query=query, load_max_docs=1)
    return loader.load()
@tool
def search_google(query: str) -> str:
    """Search Google (via the Serper API) for *query* and return the raw result.

    Args:
        query: The query to search Google for

    Returns:
        The Serper API response body (JSON text).
    """
    print(f"Searching Google for: {query}")
    url = "https://google.serper.dev/search"
    payload = json.dumps({"q": query})
    headers = {
        'X-API-KEY': os.getenv("SERPER_API_KEY"),
        'Content-Type': 'application/json'
    }
    # requests.post is the idiomatic form of requests.request("POST", ...);
    # the timeout keeps a stalled API call from hanging the agent.
    response = requests.post(url, headers=headers, data=payload, timeout=30)
    print(f"Google search result for: {query}")
    print(response.text)
    return response.text
# Every tool exposed to the agent. Math helpers and the YouTube transcript
# tool come from the local `tool` package; the rest are defined above.
tools = [
    youtube_transcript,
    analyze_image_by_url,
    read_file_by_path,
    read_file_by_url,
    load_webpage_from_url,
    load_wikipedia,
    search_google,
    multiply,
    add,
    subtract,
    divide,
    modulus
]
# Model copy with the tool schemas attached so it can emit tool calls.
llm_with_tools = llm.bind_tools(tools)
class State(TypedDict):
    """Graph state threaded through the agent workflow."""
    # Path to a locally downloaded copy of the question's attachment, if any.
    local_file_path: Optional[str]
    # Remote URL of the same attachment, if any.
    file_url: Optional[str]
    # Conversation history; add_messages appends new messages instead of replacing.
    messages: Annotated[list[AnyMessage], add_messages]
    # Final extracted answer, written by the format_answer node.
    answer: str
def should_continue(state: State):
    """Route after an agent turn: run tools when the last message requests
    tool calls, otherwise hand off to answer formatting."""
    last_message = state["messages"][-1]
    return "tools" if last_message.tool_calls else "format_answer"
def format_answer(state: State):
    """Extract the bare answer from the agent's 'FINAL ANSWER: ...' message.

    Invokes the plain (tool-free) model on only the agent's last message and
    returns a state update containing just the `answer` key.
    """
    # Prompt grammar fixed ("an AI assistant", "You need") — the original typos
    # weaken the extraction instruction given to the model.
    system_message_content = (
        "You are an AI assistant to extract the answer from the user's answer. "
        "The user's answer should be in the following format: "
        "FINAL ANSWER: [YOUR FINAL ANSWER]. "
        "You need to extract and only return the answer. If you don't find the answer, output 'N/A' "
        "Remove '.' from the end of the answer."
    )
    system_message = SystemMessage(content=system_message_content)
    # Only the agent's final message matters; earlier turns would add noise.
    messages = [system_message] + [state["messages"][-1]]
    answer = llm.invoke(messages)
    return {"answer": answer.content}
def agent(state: State):
    """Main LLM node: reasons about the question and may request tool calls.

    Builds the strict FINAL-ANSWER-format system prompt, optionally tells the
    model about an attached file, and invokes the tool-bound model on the
    conversation so far. Returns a state update appending the response to
    `messages` (merged via add_messages).
    """
    # Continuation lines are kept at column 0 on purpose: any leading
    # whitespace here would become part of the prompt string itself.
    system_message_content = "You are a general AI assistant. I will ask you a question. \
Report your thoughts, and finish your answer with the following template: \
FINAL ANSWER: [YOUR FINAL ANSWER]. \
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. \
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. \
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. \
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. \
Your answer should only start with 'FINAL ANSWER: ', then follows with the answer. "
    if state["local_file_path"]:
        # Point the model at the question's attachment (local path + mirror URL).
        system_message_content += f"\nYou can only read files I provide you. You are given a file path related to the question: {state['local_file_path']}, and the online url related to the same file: {state['file_url']}"
    system_message = SystemMessage(content=system_message_content)
    messages = [system_message] + state["messages"]
    return {"messages": [llm_with_tools.invoke(messages)]}
class Agent:
    """LangGraph-driven QA agent: agent node -> (tool loop) -> format_answer."""

    def __init__(self):
        # Fixed stale name in the log line (class is Agent, not BasicAgent).
        print("Agent initialized.")
        tool_node = ToolNode(tools)
        graph_builder = StateGraph(State)
        graph_builder.add_node("agent", agent)
        graph_builder.add_node("tools", tool_node)
        graph_builder.add_node("format_answer", format_answer)
        graph_builder.add_edge(START, "agent")
        # After each agent turn, either execute requested tools (and loop back)
        # or finalize the answer.
        graph_builder.add_conditional_edges("agent", should_continue, ["tools", "format_answer"])
        graph_builder.add_edge("tools", "agent")
        graph_builder.add_edge("format_answer", END)
        self.graph = graph_builder.compile()
        try:
            # Best-effort: save a PNG of the compiled graph for debugging.
            graph_viz = self.graph.get_graph()
            with open("graph.png", "wb") as f:
                f.write(graph_viz.draw_mermaid_png())
            print("Graph visualization saved as 'graph.png'")
        except Exception as e:
            # Rendering needs extra dependencies / network access (mermaid-based,
            # not graphviz as previously noted); failure is non-fatal.
            print(f"Could not save graph visualization: {str(e)}")

    def __call__(self, question: str, local_file_path: str | None, file_url: str | None) -> str:
        """Run the graph on *question* (plus optional attachment info) and
        return the extracted final answer."""
        result = self.graph.invoke({
            "local_file_path": local_file_path,
            "file_url": file_url,
            "messages": [HumanMessage(content=question)],
        })
        return result["answer"]