# LangGraph-based question-answering agent with web, file, image, and math tools.
| import json | |
| import os | |
| from typing import Annotated, Optional, TypedDict, List | |
| from dotenv import load_dotenv | |
| from langgraph.graph.message import add_messages | |
| from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage | |
| from langchain.chat_models import init_chat_model | |
| from langgraph.graph import StateGraph, MessagesState, START, END | |
| from langgraph.prebuilt import ToolNode | |
| import requests | |
| from langchain_community.document_loaders import WikipediaLoader | |
| from langchain_community.document_loaders import WebBaseLoader | |
| from langchain_core.tools import tool | |
| from tool.math import add, divide, multiply, subtract, modulus | |
| from tool.youtube import youtube_transcript | |
# Load environment variables (API keys, proxy, endpoint overrides) from a local .env file.
load_dotenv()

# Shared chat-model client. Endpoint, key, and proxy are all environment-driven so the
# same code can target api.openai.com or an OpenAI-compatible gateway.
llm = init_chat_model(
    model="gpt-4o",
    model_provider="openai",
    max_retries=2,
    openai_api_base=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
    openai_api_key=os.getenv("OPENAI_API_KEY"),
    openai_proxy=os.getenv("OPENAI_PROXY"),
)
def analyze_image_by_url(image_url: str, prompt: str) -> str:
    """Analyze an image with the vision-capable chat model.

    Args:
        image_url: URL of the image to analyze.
        prompt: Question or instruction to apply to the image.

    Returns:
        The model's textual answer, or "" when no URL is given.
    """
    # Falsy check (was `is None`) so empty-string URLs also degrade gracefully
    # instead of sending a malformed image part to the model.
    if not image_url:
        return ""
    # OpenAI-style multimodal message: one text part plus one image_url part.
    response = llm.invoke([
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": image_url}},
            ],
        }
    ])
    print(f"Response: {response.content}")
    return response.content
def read_file_by_path(file_path: str) -> str:
    """Read a local file and return its text content.

    Args:
        file_path: Path of the file to read.

    Returns:
        The file's text, or "" when no path is given.
    """
    print(f"Reading file: {file_path}")
    # Falsy check (was `is None`) so an empty-string path is also treated as "no file".
    if not file_path:
        return ""
    # NOTE(review): assumes a text file; binary content will raise a decode error here.
    with open(file_path, "r") as f:
        return f.read()
def read_file_by_url(file_url: str) -> str:
    """Download a file and return its raw content.

    Args:
        file_url: The url of the file to read.

    Returns:
        The raw response body, or "" when no URL is given.
    """
    print(f"Reading file: {file_url}")
    # Falsy check (was `is None`) so empty-string URLs are also treated as "no file".
    if not file_url:
        return ""
    # Timeout prevents the agent from hanging forever on an unresponsive host.
    response = requests.get(file_url, timeout=30)
    # NOTE(review): this returns bytes despite the `str` annotation — the docstring
    # says "raw content", so callers appear to expect that; confirm before changing.
    return response.content
def load_webpage_from_url(url: str) -> str:
    """Fetch and parse the page at the given url.

    Args:
        url: The url of the webpage to load

    Returns:
        The content of the webpage
    """
    print(f"Loading webpage from: {url}")
    loader = WebBaseLoader(url)
    return loader.load()
def load_wikipedia(query: str) -> str:
    """Fetch the top Wikipedia document matching the query.

    Args:
        query: The query to search Wikipedia for

    Returns:
        The content of the Wikipedia page
    """
    print(f"Loading Wikipedia for: {query}")
    loader = WikipediaLoader(query=query, load_max_docs=1)
    return loader.load()
def search_google(query: str) -> str:
    """Search Google (via the Serper API) and return the raw JSON result.

    Args:
        query: The query to search Google for

    Returns:
        The Serper API response body as a JSON string.
    """
    print(f"Searching Google for: {query}")
    url = "https://google.serper.dev/search"
    headers = {
        'X-API-KEY': os.getenv("SERPER_API_KEY"),
        'Content-Type': 'application/json'
    }
    # requests.post with json= replaces the manual json.dumps + requests.request("POST", ...);
    # the timeout keeps a dead endpoint from blocking the whole agent run.
    response = requests.post(url, headers=headers, json={"q": query}, timeout=30)
    print(f"Google search result for: {query}")
    print(response.text)
    return response.text
# Tools exposed to the model, in the order they are advertised in the tool schema.
tools = [
    # Retrieval / media tools.
    youtube_transcript,
    analyze_image_by_url,
    read_file_by_path,
    read_file_by_url,
    load_webpage_from_url,
    load_wikipedia,
    search_google,
    # Basic arithmetic tools.
    multiply,
    add,
    subtract,
    divide,
    modulus
]

# Model handle that can emit tool calls for any of the tools above.
llm_with_tools = llm.bind_tools(tools)
class State(TypedDict):
    """Graph state threaded through the agent workflow."""
    # Path to a locally available copy of the task file, if any.
    local_file_path: Optional[str]
    # Remote URL of the same task file, if any.
    file_url: Optional[str]
    # Conversation history; add_messages appends new messages instead of overwriting.
    messages: Annotated[list[AnyMessage], add_messages]
    # Final extracted answer produced by the format_answer node.
    answer: str
def should_continue(state: State):
    """Route from the agent node: run tools if the last message requested
    any, otherwise hand off to answer extraction."""
    last = state["messages"][-1]
    return "tools" if last.tool_calls else "format_answer"
def format_answer(state: State):
    """Extract the bare answer from the agent's final 'FINAL ANSWER: ...' message."""
    # Extraction prompt — kept byte-for-byte; the model depends on this exact wording.
    instructions = "You are a AI assistant to extract the answer from the user's answer. \
The user's answer should be in the following format: \
FINAL ANSWER: [YOUR FINAL ANSWER]. \
Your need to extract and only return the answer. If you don't find the answer, output 'N/A' \
Remove '.' from the end of the answer."
    final_message = state["messages"][-1]
    extraction = llm.invoke([SystemMessage(content=instructions), final_message])
    return {"answer": extraction.content}
def agent(state: State):
    """Main reasoning node: prompt the tool-enabled model with the conversation."""
    system_message_content = "You are a general AI assistant. I will ask you a question. \
Report your thoughts, and finish your answer with the following template: \
FINAL ANSWER: [YOUR FINAL ANSWER]. \
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. \
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. \
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. \
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. \
Your answer should only start with 'FINAL ANSWER: ', then follows with the answer. "
    # .get (was state["local_file_path"]) so a state dict missing the file keys
    # cannot raise KeyError when the graph is invoked without file context.
    if state.get("local_file_path"):
        system_message_content += f"\nYou can only read files I provide you. You are given a file path related to the question: {state['local_file_path']}, and the online url related to the same file: {state['file_url']}"
    system_message = SystemMessage(content=system_message_content)
    messages = [system_message] + state["messages"]
    return {"messages": [llm_with_tools.invoke(messages)]}
class Agent:
    """LangGraph-backed question-answering agent.

    Wires the agent/tools/format_answer nodes into a compiled graph and exposes
    a callable interface that returns the extracted final answer.
    """

    def __init__(self):
        print("BasicAgent initialized.")
        self.graph = self._build_graph()
        self._save_graph_png()

    @staticmethod
    def _build_graph():
        """Assemble and compile the agent workflow graph."""
        graph_builder = StateGraph(State)
        graph_builder.add_node("agent", agent)
        graph_builder.add_node("tools", ToolNode(tools))
        graph_builder.add_node("format_answer", format_answer)
        graph_builder.add_edge(START, "agent")
        # After each agent turn: run requested tools, or extract the final answer.
        graph_builder.add_conditional_edges("agent", should_continue, ["tools", "format_answer"])
        graph_builder.add_edge("tools", "agent")
        graph_builder.add_edge("format_answer", END)
        return graph_builder.compile()

    def _save_graph_png(self):
        """Best-effort: save a PNG visualization of the graph as 'graph.png'."""
        try:
            graph_viz = self.graph.get_graph()
            with open("graph.png", "wb") as f:
                f.write(graph_viz.draw_mermaid_png())
            print("Graph visualization saved as 'graph.png'")
        except Exception as e:
            # Rendering needs optional native dependencies (graphviz); never
            # fail agent startup over a missing visualization.
            print(f"Could not save graph visualization: {str(e)}")

    def __call__(self, question: str, local_file_path: str | None, file_url: str | None) -> str:
        """Run the graph on a question (plus optional file context) and return the answer."""
        result = self.graph.invoke({
            "local_file_path": local_file_path,
            "file_url": file_url,
            "messages": [HumanMessage(content=question)],
        })
        return result["answer"]