File size: 5,994 Bytes
d4fe024
94eedaa
 
a85a581
b96b4c9
 
 
 
63afde9
 
 
b96b4c9
 
 
 
 
 
7c0b69a
b96b4c9
 
 
 
d4fe024
17fe068
b96b4c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55b28c9
b96b4c9
55b28c9
 
63afde9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b667284
b96b4c9
 
 
 
 
 
 
 
63afde9
b96b4c9
 
 
55b28c9
b96b4c9
55b28c9
b96b4c9
 
 
 
 
 
 
 
 
 
 
 
 
 
55b28c9
b96b4c9
 
55b28c9
b96b4c9
 
 
 
 
 
 
 
 
 
 
 
aaa8150
b96b4c9
 
 
0e3134d
b96b4c9
 
 
 
e1322e0
b96b4c9
 
 
c0f4e2d
7c7b8a8
5bd8e40
7c7b8a8
4003c5d
 
 
c662dbe
 
 
7c0b69a
c662dbe
 
 
 
 
 
 
 
 
 
 
 
 
 
b96b4c9
 
 
 
 
c0f4e2d
b96b4c9
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import os
import time


from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode

from langchain_community.tools import DuckDuckGoSearchResults


from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.document_loaders import ArxivLoader

from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool


from langchain_google_genai import ChatGoogleGenerativeAI

#load_dotenv()
# Required credentials; os.environ[...] raises KeyError at import time if missing.
google_api_key = os.environ["GOOGLE_API_KEY"]
hf_api_key = os.environ["HF_TOKEN"]  # NOTE(review): hf_api_key is never used in this file — confirm it is needed

@tool
def add(a: int, b: int) -> int:
  """ Add a and b """
  total = a + b
  return total
@tool
def subtract(a: int, b: int) -> int:
  """ Subtract b from a """
  # Docstring typo fixed ("Subract" -> "Subtract"): @tool uses the docstring
  # as the tool description shown to the LLM, so the typo was user-visible.
  return a - b
@tool
def multiply(a: int, b: int) -> int:
  """ Multiply a and b """
  product = a * b
  return product
@tool
def divide(a: int, b: int) -> float:
  """ Divide a by b """
  # Guard clause: division by zero is rejected explicitly.
  if b == 0:
    raise ValueError("Can't divide by 0.")
  quotient = a / b
  return quotient

@tool
def web_search(query: str) -> dict:
  """ Search for a query on web and return best result.

  Args:
      query: The search terms to look up.

  Returns:
      A dict with key "web_results" holding the raw DuckDuckGo results string.
  """
  # Fix: the return annotation said `-> str` but the function returns a dict;
  # annotation corrected to match actual behavior. Dead commented-out
  # formatting code removed.
  search = DuckDuckGoSearchResults(num_results=1)
  results = search.invoke(input=query)
  return {"web_results": results}



    
# NOTE: disabled alternative implementation of web_search that uses Tavily
# instead of DuckDuckGo; kept for reference as a bare string so it is never
# executed. Only one @tool named web_search may be active at a time.
'''@tool
def web_search(query: str) -> str:
  """ Search for a query on web and return best 2 result."""

  search_results = TavilySearchResults(max_results = 2).invoke(input=query)

  formatted_search_results = "\n\n-----\n\n".join(
      [
          #f'<Result: source = "{result.metadata["source"]}", page = "{result.metadata.get("page","")}">\n{result.page_content}\n </Result>'
          f'<Result: source = "{result.get("url", "")}", page = "{result.get("title","")}">\n{result.get("content","")}\n </Result>'
          for result in search_results
      ]
  )
  return {"web_results" : formatted_search_results}'''

@tool
def wikipedia_search(query: str) -> str:
  """ Search for a query on wikipedia and return best result."""
  # Load the single best-matching Wikipedia article for the query.
  docs = WikipediaLoader(query=query, load_max_docs=1).load()

  # Each doc is a Document: metadata carries source/title, page_content the text.
  entries = []
  for doc in docs:
    source = doc.metadata.get("source", "")
    title = doc.metadata.get("title", "")
    entries.append(
        f'<Result: source = "{source}", page = "{title}">\n{doc.page_content}\n </Result>'
    )

  return {"Wikipedia_results": "\n\n-----\n\n".join(entries)}

@tool
def arxiv_search(query: str) -> str:
  """ Search for a query on arxiv and return best result."""
  # query and load_max_docs are passed at construction; load() takes no args.
  docs = ArxivLoader(query=query, load_max_docs=1).load()

  entries = []
  for doc in docs:
    source = doc.metadata.get("source", "")
    title = doc.metadata.get("title", "")
    entries.append(
        f'<Result: source = "{source}", page = "{title}">\n{doc.page_content}\n </Result>'
    )

  return {"arxiv_results": "\n\n-----\n\n".join(entries)}



# System prompt enforcing the GAIA-style "FINAL ANSWER:" output format.
system_prompt = """You are a general AI assistant. I will ask you a question. Use your tools and think step by step to report your thoughts, and finish your answer with the following template:
FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."""
# Adding "Use your tools" to the prompt made the model actually perform tool calling.
system_message = SystemMessage(content=system_prompt)

# All tools exposed to the agent; bound to the LLM in build_graph().
tools = [
    add,subtract,multiply,divide,web_search,wikipedia_search,arxiv_search
]


def build_graph(provider: str = "google"):
    """Build and compile the tool-calling agent graph.

    Args:
        provider: LLM provider name; currently only "google" (Gemini) is wired up.

    Returns:
        A compiled LangGraph runnable implementing an assistant <-> tools loop.
    """
    # Google Gemini (the only provider currently supported).
    llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0, api_key=google_api_key)
    # Bind tools so the model can emit tool calls.
    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """Call the tool-enabled LLM on the system prompt plus conversation."""
        response = llm_with_tools.invoke([system_message] + state["messages"])
        time.sleep(4)  # throttle to stay within the free-tier rate limit
        # Fix: MessagesState uses the add_messages reducer, which appends/merges
        # returned messages into state. Return only the NEW message; returning
        # state["messages"] + [response] re-merged the entire history each turn.
        return {"messages": [response]}

    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "assistant")
    # Route to "tools" when the last message carries tool calls, else END.
    builder.add_conditional_edges(
        "assistant",
        tools_condition
    )
    builder.add_edge("tools", "assistant")

    return builder.compile()

# Smoke test: run the compiled graph on one sample question.
if __name__ == "__main__":
    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
    graph = build_graph(provider="google")
    result = graph.invoke({"messages": [HumanMessage(content=question)]})
    for message in result["messages"]:
        message.pretty_print()