Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| from langgraph.graph import START, StateGraph, MessagesState | |
| from langgraph.prebuilt import tools_condition | |
| from langgraph.prebuilt import ToolNode | |
| from langchain_community.tools import DuckDuckGoSearchResults | |
| from langchain_community.tools.tavily_search import TavilySearchResults | |
| from langchain_community.document_loaders import WikipediaLoader | |
| from langchain_community.document_loaders import ArxivLoader | |
| from langchain_core.messages import SystemMessage, HumanMessage | |
| from langchain_core.tools import tool | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
# load_dotenv()  # disabled; `from dotenv import load_dotenv` is not imported in this file
# The Gemini key is mandatory: importing this module fails fast with KeyError if unset.
google_api_key = os.environ["GOOGLE_API_KEY"]
# HF_TOKEN is not used by any code visible in this file, so read it optionally
# instead of crashing the whole import when the variable is missing.
hf_api_key = os.environ.get("HF_TOKEN")
def add(a: int, b: int) -> int:
    """Add a and b.

    Args:
        a: First addend.
        b: Second addend.

    Returns:
        The sum of a and b.
    """
    total = a + b
    return total
def subtract(a: int, b: int) -> int:
    """Subtract b from a.

    Args:
        a: The number to subtract from.
        b: The amount to subtract.

    Returns:
        The difference a - b.
    """
    # The docstring doubles as the tool description sent to the LLM by bind_tools,
    # so its wording (previously misspelled "Subract") matters at runtime.
    return a - b
def multiply(a: int, b: int) -> int:
    """Multiply a and b.

    Args:
        a: First factor.
        b: Second factor.

    Returns:
        The product of a and b.
    """
    product = a * b
    return product
def divide(a: int, b: int) -> float:
    """Divide a by b.

    Args:
        a: The dividend.
        b: The divisor, which must be non-zero.

    Returns:
        The true-division quotient a / b.

    Raises:
        ValueError: If b is zero.
    """
    if b != 0:
        return a / b
    raise ValueError("Can't divide by 0.")
def web_search(query: str) -> dict:
    """Search the web for a query and return the best result.

    Args:
        query: Search terms passed to DuckDuckGo.

    Returns:
        A dict with the single key "web_results" holding the output of
        DuckDuckGoSearchResults, passed through unchanged.
    """
    # num_results=1: only the top hit is requested.
    search = DuckDuckGoSearchResults(num_results=1)
    results = search.invoke(input=query)
    # A disabled earlier variant used TavilySearchResults(max_results=2) (still
    # imported at the top of the file) and wrapped each hit in a
    # <Result: source = ..., page = ...> block, the same format wikipedia_search uses.
    return {"web_results": results}
def wikipedia_search(query: str) -> dict[str, str]:
    """Search Wikipedia for a query and return the best result.

    Args:
        query: Search terms passed to WikipediaLoader.

    Returns:
        A dict with the single key "Wikipedia_results" whose value is each loaded
        page wrapped in a <Result ...> block, joined by a separator line. The
        value is an empty string if the loader returns no documents.
    """
    # query and load_max_docs are constructor arguments; load() itself takes none.
    loader = WikipediaLoader(query=query, load_max_docs=1)
    documents = loader.load()
    formatted = "\n\n-----\n\n".join(
        f'<Result: source = "{doc.metadata.get("source", "")}", page = "{doc.metadata.get("title","")}">\n{doc.page_content}\n </Result>'
        for doc in documents
    )
    return {"Wikipedia_results": formatted}
def arxiv_search(query: str) -> dict[str, str]:
    """Search arXiv for a query and return the best result.

    Args:
        query: Search terms passed to ArxivLoader.

    Returns:
        A dict with the single key "arxiv_results" whose value is each loaded
        paper wrapped in a <Result ...> block, joined by a separator line. The
        value is an empty string if the loader returns no documents.
    """

    def _first_present(metadata, *keys):
        # Return the first non-empty value among `keys`, else "". ArxivLoader's
        # metadata uses capitalised keys ("Title", ...) rather than the lowercase
        # "title"/"source" used by WikipediaLoader, so look up both spellings.
        for key in keys:
            value = metadata.get(key)
            if value:
                return value
        return ""

    def _format(doc) -> str:
        # NOTE(review): "entry_id" (the paper URL) is presumably only present when the
        # loader is built with load_all_available_meta=True; otherwise source is "".
        source = _first_present(doc.metadata, "source", "entry_id")
        title = _first_present(doc.metadata, "title", "Title")
        return f'<Result: source = "{source}", page = "{title}">\n{doc.page_content}\n </Result>'

    # As with WikipediaLoader, query and load_max_docs are constructor arguments.
    loader = ArxivLoader(query=query, load_max_docs=1)
    documents = loader.load()
    return {"arxiv_results": "\n\n-----\n\n".join(_format(doc) for doc in documents)}
# Answer-format instructions. build_graph's assistant node prepends system_message
# to the conversation on every model call.
system_prompt = """You are a general AI assistant. I will ask you a question. Use your tools and think step by step to report your thoughts, and finish your answer with the following template:
FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."""
# Adding the "Use your tools" instruction to this prompt is what got the model to start making tool calls.
system_message = SystemMessage(content=system_prompt)
# Every tool the model may call. build_graph binds this list to the LLM and also
# hands it to the ToolNode that executes the requested calls.
tools = [
    add,
    subtract,
    multiply,
    divide,
    web_search,
    wikipedia_search,
    arxiv_search,
]
def build_graph(provider: str = "google", *, rate_limit_delay: float = 4.0):
    """Build and compile the tool-calling agent graph.

    Execution starts at the "assistant" node, which calls Gemini with the
    module-level `tools` bound. `tools_condition` routes to the "tools" node
    (a ToolNode running the requested tool calls) when the model asked for
    tools, and tool results flow back to "assistant".

    Args:
        provider: LLM provider name. NOTE(review): currently ignored; Google
            Gemini ("gemini-2.0-flash") is always used whatever value is passed.
        rate_limit_delay: Seconds to sleep after each model call to throttle
            requests. Defaults to 4.0, the delay previously hard-coded for the
            free-tier quota. Values <= 0 disable the pause.

    Returns:
        The compiled graph. Invoke it with {"messages": [HumanMessage(...)]}.
    """
    # Only the Google provider is implemented; `provider` is kept for API compatibility.
    llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0, api_key=google_api_key)
    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """Run one model turn over the system prompt plus the conversation so far."""
        response = llm_with_tools.invoke([system_message] + state["messages"])
        if rate_limit_delay > 0:
            # Throttle consecutive calls to stay under the free-tier rate limit.
            time.sleep(rate_limit_delay)
        # MessagesState merges node output into "messages" via its add_messages
        # reducer, so return only the new message instead of re-sending the history.
        return {"messages": [response]}

    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "assistant")
    # Go to "tools" when the last AI message contains tool calls, otherwise finish.
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    return builder.compile()
# Manual smoke test: send one question through the agent and print every
# message of the resulting conversation.
if __name__ == "__main__":
    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
    agent_graph = build_graph(provider="google")
    final_state = agent_graph.invoke({"messages": [HumanMessage(content=question)]})
    for message in final_state["messages"]:
        message.pretty_print()