CheeYung commited on
Commit
7d88664
·
1 Parent(s): 065bc2a

Update basic agent framework

Browse files
Files changed (6) hide show
  1. agent.py +100 -3
  2. app.py +11 -4
  3. prompt.txt +49 -0
  4. requirements.txt +14 -1
  5. sample.ipynb +43 -20
  6. supabase.sql +34 -26
agent.py CHANGED
@@ -1,13 +1,56 @@
1
  import os
 
2
  from typing import TypedDict, Annotated
 
3
  from langgraph.graph import MessagesState, START, StateGraph
4
  from langgraph.graph.message import add_messages
5
  from langgraph.prebuilt import tools_condition, ToolNode
6
  from langchain_core.messages import HumanMessage, SystemMessage, AnyMessage
7
  from langchain_core.tools import tool
 
 
8
  from langchain_community.tools.tavily_search import TavilySearchResults
 
 
 
9
  from langchain_google_genai import ChatGoogleGenerativeAI
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  @tool
12
  def add(a: int, b: int) -> int:
13
  """Add two numbers.
@@ -65,6 +108,49 @@ def modulus(a: int, b: int) -> int:
65
  """
66
  return a % b
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  # list of tools
69
  tools = [
70
  add,
@@ -72,7 +158,10 @@ tools = [
72
  multiply,
73
  power,
74
  divide,
75
- modulus
 
 
 
76
  ]
77
 
78
  # Generate the AgentState and Agent graph
@@ -90,25 +179,33 @@ def build_graph():
90
  return { "messages": [llm_with_tools.invoke(state['messages'])] }
91
 
92
  def retriever(state: AgentState):
93
- return None
 
 
 
 
94
 
95
  builder = StateGraph(AgentState)
96
 
97
  # Define nodes: these do the work
98
  builder.add_node("assistant", assistant)
 
99
  builder.add_node("tools", ToolNode(tools))
 
100
  builder.add_conditional_edges(
101
  "assistant",
102
  tools_condition
103
  )
104
  builder.add_edge("tools", "assistant")
 
105
 
106
  # Compile graph
107
  return builder.compile()
108
 
109
  # Test
110
  if __name__ == "__main__":
111
- question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
 
112
  graph = build_graph()
113
  messages = [HumanMessage(content=question)]
114
  messages = graph.invoke({ "messages": messages })
 
1
  import os
2
+ from dotenv import load_dotenv
3
  from typing import TypedDict, Annotated
4
+
5
  from langgraph.graph import MessagesState, START, StateGraph
6
  from langgraph.graph.message import add_messages
7
  from langgraph.prebuilt import tools_condition, ToolNode
8
  from langchain_core.messages import HumanMessage, SystemMessage, AnyMessage
9
  from langchain_core.tools import tool
10
+ from langchain.tools.retriever import create_retriever_tool
11
+
12
  from langchain_community.tools.tavily_search import TavilySearchResults
13
+ from langchain_community.document_loaders import WikipediaLoader
14
+ from langchain_community.document_loaders import ArxivLoader
15
+
16
  from langchain_google_genai import ChatGoogleGenerativeAI
17
 
18
+ from langchain_huggingface import HuggingFaceEmbeddings
19
+
20
+ from langchain_community.vectorstores import SupabaseVectorStore
21
+ from langchain.schema.document import Document
22
+ from supabase import create_client, Client
23
+
24
+ load_dotenv()
25
+
26
+ __embeddings = HuggingFaceEmbeddings(
27
+ model_name="sentence-transformers/all-mpnet-base-v2",
28
+ model_kwargs= { 'device': 'cpu' })
29
+
30
+ # connect to supabase
31
+ url: str = os.environ.get("SUPABASE_URL")
32
+ key: str = os.environ.get("SUPABASE_SECRET_KEY")
33
+ __supabase: Client = create_client(url, key)
34
+
35
+ # build retriever
36
+ vector_store = SupabaseVectorStore(
37
+ client=__supabase,
38
+ embedding=__embeddings,
39
+ table_name="documents",
40
+ query_name="match_documents",
41
+ )
42
+ question_retrieval_tool = create_retriever_tool(
43
+ vector_store.as_retriever(),
44
+ name="Question retriever",
45
+ description="Find similar questions in the vector database for the given question."
46
+ )
47
+
48
+ # load prompt message from txt file and convert to System Message
49
+ with open('prompt.txt', 'r', encoding='utf-8') as f:
50
+ sys_prompt = f.read()
51
+
52
+ __sys_msg = SystemMessage(content=sys_prompt)
53
+
54
  @tool
55
  def add(a: int, b: int) -> int:
56
  """Add two numbers.
 
108
  """
109
  return a % b
110
 
111
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return a maximum of 2 results.

    Args:
        query: The search query.

    Returns:
        The matching pages concatenated as pseudo-XML <Document> blocks,
        separated by "---" dividers, so the LLM can attribute content to
        its source page.
    """
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    # Close the envelope with </Document> (the original emitted a second
    # opening tag) and return a plain string to honor the -> str contract
    # declared in the signature.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n\t{doc.page_content}\n</Document>'
        for doc in search_docs
    )
    return formatted_search_docs
124
+
125
@tool
def web_search(query: str) -> str:
    """Search Tavily for a query and return a maximum of 3 results.

    Args:
        query: The search query.

    Returns:
        The matching results concatenated as pseudo-XML <Document> blocks,
        separated by "---" dividers.
    """
    # TavilySearchResults.invoke takes the query string as its single
    # positional input (invoke(query=query) is a TypeError) and returns a
    # list of plain dicts with "url" and "content" keys — not Document
    # objects, so doc.metadata[...] would have raised AttributeError.
    search_results = TavilySearchResults(max_results=3).invoke(query)
    formatted_search_results = "\n\n---\n\n".join(
        f'<Document source="{res["url"]}"/>\n\t{res["content"]}\n</Document>'
        for res in search_results
    )
    return formatted_search_results
138
+
139
@tool
def arxiv_search(query: str) -> str:
    """Search Arxiv for a query and return a maximum of 3 results.

    Args:
        query: The search query.

    Returns:
        The matching papers (first 1000 chars each) concatenated as
        pseudo-XML <Document> blocks, separated by "---" dividers.
    """
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    # NOTE(review): ArxivLoader metadata typically exposes keys like
    # "Published"/"Title" rather than "source" — use .get so a missing key
    # degrades to an empty attribute instead of raising; confirm against
    # the installed langchain-community version.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("source", "")}" page="{doc.metadata.get("page", "")}"/>\n\t{doc.page_content[:1000]}\n</Document>'
        for doc in search_docs
    )
    return formatted_search_docs
152
+
153
+
154
  # list of tools
155
  tools = [
156
  add,
 
158
  multiply,
159
  power,
160
  divide,
161
+ modulus,
162
+ wiki_search,
163
+ web_search,
164
+ arxiv_search
165
  ]
166
 
167
  # Generate the AgentState and Agent graph
 
179
  return { "messages": [llm_with_tools.invoke(state['messages'])] }
180
 
181
  def retriever(state: AgentState):
182
+ similar_question = vector_store.similarity_search(state['messages'][0].content)
183
+ example_msg = HumanMessage(
184
+ content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
185
+ )
186
+ return { "messages": [__sys_msg] + state['messages'] + [example_msg] }
187
 
188
  builder = StateGraph(AgentState)
189
 
190
  # Define nodes: these do the work
191
  builder.add_node("assistant", assistant)
192
+ builder.add_node("retriever", retriever)
193
  builder.add_node("tools", ToolNode(tools))
194
+ builder.add_edge(START, "retriever")
195
  builder.add_conditional_edges(
196
  "assistant",
197
  tools_condition
198
  )
199
  builder.add_edge("tools", "assistant")
200
+ builder.add_edge("retriever", "assistant")
201
 
202
  # Compile graph
203
  return builder.compile()
204
 
205
  # Test
206
  if __name__ == "__main__":
207
+ # question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
208
+ question = "Data feed 100ms happened once. If 50second "
209
  graph = build_graph()
210
  messages = [HumanMessage(content=question)]
211
  messages = graph.invoke({ "messages": messages })
app.py CHANGED
@@ -3,6 +3,8 @@ import gradio as gr
3
  import requests
4
  import inspect
5
  import pandas as pd
 
 
6
 
7
  # (Keep Constants as is)
8
  # --- Constants ---
@@ -12,12 +14,17 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
12
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
13
  class BasicAgent:
14
  def __init__(self):
 
15
  print("BasicAgent initialized.")
16
  def __call__(self, question: str) -> str:
17
- print(f"Agent received question (first 50 chars): {question[:50]}...")
18
- fixed_answer = "This is a default answer."
19
- print(f"Agent returning fixed answer: {fixed_answer}")
20
- return fixed_answer
 
 
 
 
21
 
22
  def run_and_submit_all( profile: gr.OAuthProfile | None):
23
  """
 
3
  import requests
4
  import inspect
5
  import pandas as pd
6
+ from langchain_core.messages import HumanMessage
7
+ from agent import build_graph
8
 
9
  # (Keep Constants as is)
10
  # --- Constants ---
 
14
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
15
  class BasicAgent:
16
  def __init__(self):
17
+ self.graph = build_graph()
18
  print("BasicAgent initialized.")
19
  def __call__(self, question: str) -> str:
20
+ messages = [HumanMessage(content=question)]
21
+ messages = self.graph.invoke({ "messages": messages })
22
+ answer = messages['messages'][-1].content
23
+
24
+ # print(f"Agent received question (first 50 chars): {question[:50]}...")
25
+ # fixed_answer = "This is a default answer."
26
+ # print(f"Agent returning fixed answer: {fixed_answer}")
27
+ return answer[14:]
28
 
29
  def run_and_submit_all( profile: gr.OAuthProfile | None):
30
  """
prompt.txt ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ You are a helpful agent responsible for answering questions using a set of tools provided.
2
+ If the tool(s) not available, you can try to search and find the solution or information online.
3
+ You can also use your own knowledge to answer the question.
4
+ ==========================
5
+ Here are a few examples showing you how to answer the question step by step.
6
+
7
+ Question 1: In terms of geographical distance between capital cities, which 2 countries are the furthest from each other within the ASEAN bloc according to wikipedia? Answer using a comma separated list, ordering the countries by alphabetical order.
8
+ Steps:
9
+ 1. Search the web for "ASEAN bloc".
10
+ 2. Click the Wikipedia result for the ASEAN Free Trade Area.
11
+ 3. Scroll down to find the list of member states.
12
+ 4. Click into the Wikipedia pages for each member state, and note its capital.
13
+ 5. Search the web for the distance between the first two capitals. The results give travel distance, not geographic distance, which might affect the answer.
14
+ 6. Thinking it might be faster to judge the distance by looking at a map, search the web for "ASEAN bloc" and click into the images tab.
15
+ 7. View a map of the member countries. Since they're clustered together in an arrangement that's not very linear, it's difficult to judge distances by eye.
16
+ 8. Return to the Wikipedia page for each country. Click the GPS coordinates for each capital to get the coordinates in decimal notation.
17
+ 9. Place all these coordinates into a spreadsheet.
18
+ 10. Write formulas to calculate the distance between each capital.
19
+ 11. Write formula to get the largest distance value in the spreadsheet.
20
+ 12. Note which two capitals that value corresponds to: Jakarta and Naypyidaw.
21
+ 13. Return to the Wikipedia pages to see which countries those respective capitals belong to: Indonesia, Myanmar.
22
+ Tools:
23
+ 1. Search engine
24
+ 2. Web browser
25
+ 3. Microsoft Excel / Google Sheets
26
+ Final Answer: Indonesia, Myanmar
27
+
28
+ Question 2: Review the chess position provided in the image. It is black's turn. Provide the correct next move for black which guarantees a win. Please provide your response in algebraic notation.
29
+ Steps:
30
+ Step 1: Evaluate the position of the pieces in the chess position
31
+ Step 2: Report the best move available for black: "Rd5"
32
+ Tools:
33
+ 1. Image recognition tools
34
+ Final Answer: Rd5
35
+
36
+ Question 3: Solve the equation x^2 + 5x = -6
37
+ Steps:
38
+ Step 1: Moving all terms to left-hand side until the right-hand side become zero.
39
+ Step 2: Identify the highest power of polynomial in left-hand side. In this case the highest power is 2, this equation is a quadratic equation.
40
+ Step 3: Identify the coefficients of each term in this quadratic equation.
41
+ Step 3: Write quadratic formula and calculate the possible solutions.
42
+ Tools:
43
+ 1. Search engine
44
+ 2. Web browser
45
+ 3. Calculator
46
+ Final Answer: x=-2, x=-3
47
+
48
+ ==========================
49
+ Now, please answer the following question step by step.
requirements.txt CHANGED
@@ -1,4 +1,17 @@
1
  gradio
2
  requests
3
  langchain
4
- langchain-google-genai
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  gradio
2
  requests
3
  langchain
4
+ langchain-community
5
+ langchain-core
6
+ langchain-google-genai
7
+ langchain-huggingface
8
+ langchain-tavily
9
+ langchain-chroma
10
+ langgraph
11
+ huggingface_hub
12
+ supabase
13
+ arxiv
14
+ pymupdf
15
+ wikipedia
16
+ pgvector
17
+ python-dotenv
sample.ipynb CHANGED
@@ -230,7 +230,7 @@
230
  },
231
  {
232
  "cell_type": "code",
233
- "execution_count": 6,
234
  "id": "42263deb",
235
  "metadata": {},
236
  "outputs": [],
@@ -246,14 +246,14 @@
246
  " docs.append(doc)\n",
247
  "\n",
248
  "# insert the documents to the vector database\n",
249
- "try:\n",
250
- " response = (\n",
251
- " supabase.table('documents')\n",
252
- " .insert(docs)\n",
253
- " .execute()\n",
254
- " )\n",
255
- "except Exception as exception:\n",
256
- " print(\"Error inserting data into Supabase:\", exception)"
257
  ]
258
  },
259
  {
@@ -273,17 +273,6 @@
273
  "retriever = vector_store.as_retriever()"
274
  ]
275
  },
276
- {
277
- "cell_type": "code",
278
- "execution_count": null,
279
- "id": "ff5934c3",
280
- "metadata": {},
281
- "outputs": [],
282
- "source": [
283
- "# query = \"What did the president say about Ketanji Brown Jackson\"\n",
284
- "# matched_docs = vector_store.similarity_search(query, 2)"
285
- ]
286
- },
287
  {
288
  "cell_type": "code",
289
  "execution_count": 11,
@@ -307,6 +296,40 @@
307
  "docs = retriever.invoke(query)\n",
308
  "docs[0]"
309
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  }
311
  ],
312
  "metadata": {
 
230
  },
231
  {
232
  "cell_type": "code",
233
+ "execution_count": null,
234
  "id": "42263deb",
235
  "metadata": {},
236
  "outputs": [],
 
246
  " docs.append(doc)\n",
247
  "\n",
248
  "# insert the documents to the vector database\n",
249
+ "#try:\n",
250
+ "# response = (\n",
251
+ "# supabase.table('documents')\n",
252
+ "# .insert(docs)\n",
253
+ "# .execute()\n",
254
+ "# )\n",
255
+ "#except Exception as exception:\n",
256
+ "# print(\"Error inserting data into Supabase:\", exception)"
257
  ]
258
  },
259
  {
 
273
  "retriever = vector_store.as_retriever()"
274
  ]
275
  },
 
 
 
 
 
 
 
 
 
 
 
276
  {
277
  "cell_type": "code",
278
  "execution_count": 11,
 
296
  "docs = retriever.invoke(query)\n",
297
  "docs[0]"
298
  ]
299
+ },
300
+ {
301
+ "cell_type": "markdown",
302
+ "id": "a2e6497a",
303
+ "metadata": {},
304
+ "source": [
305
+ "# Tavily Search"
306
+ ]
307
+ },
308
+ {
309
+ "cell_type": "code",
310
+ "execution_count": 12,
311
+ "id": "a9448c8c",
312
+ "metadata": {},
313
+ "outputs": [],
314
+ "source": [
315
+ "from langchain_community.tools.tavily_search import TavilySearchResults\n",
316
+ "from langchain_community.document_loaders import WikipediaLoader\n",
317
+ "from langchain_community.document_loaders import ArxivLoader"
318
+ ]
319
+ },
320
+ {
321
+ "cell_type": "code",
322
+ "execution_count": 13,
323
+ "id": "c3de569e",
324
+ "metadata": {},
325
+ "outputs": [],
326
+ "source": [
327
+ "question_retrieval_tool = create_retriever_tool(\n",
328
+ " vector_store.as_retriever(),\n",
329
+ " name=\"Question retriever\",\n",
330
+ " description=\"Find similar questions in the vector database for the given question.\"\n",
331
+ ")"
332
+ ]
333
  }
334
  ],
335
  "metadata": {
supabase.sql CHANGED
@@ -1,30 +1,38 @@
1
  -- Drop old function
2
  drop function if exists match_documents (vector(1536), int);
3
 
 
 
 
 
 
 
 
 
4
  -- Create a function to search for documents
5
- create function match_documents (
6
- query_embedding vector(1536),
7
- match_count int DEFAULT null,
8
- filter jsonb DEFAULT '{}'
9
- ) returns table (
10
- id bigint,
11
- content text,
12
- metadata jsonb,
13
- similarity float
14
- )
15
- language plpgsql
16
- as $$
17
- #variable_conflict use_column
18
- begin
19
- return query
20
- select
21
- id,
22
- content,
23
- metadata,
24
- 1 - (documents.embedding <=> query_embedding) as similarity
25
- from documents
26
- where metadata @> filter
27
- order by documents.embedding <=> query_embedding
28
- limit match_count;
29
- end;
30
- $$;
 
1
  -- Drop old function
2
  drop function if exists match_documents (vector(1536), int);
3
 
4
+ -- Create a table to store your documents
5
+ create table documents (
6
+ id bigserial primary key,
7
+ content text, -- corresponds to Document.pageContent
8
+ metadata jsonb, -- corresponds to Document.metadata
9
+ embedding vector(768) -- 768 works for Gemini embeddings, change if needed
10
+ );
11
+
12
  -- Create a function to search for documents
13
+ create function match_documents (
14
+ query_embedding vector(768),
15
+ match_count int DEFAULT null,
16
+ filter jsonb DEFAULT '{}'
17
+ ) returns table (
18
+ id bigint,
19
+ content text,
20
+ metadata jsonb,
21
+ similarity float
22
+ )
23
+ language plpgsql
24
+ as $$
25
+ #variable_conflict use_column
26
+ begin
27
+ return query
28
+ select
29
+ id,
30
+ content,
31
+ metadata,
32
+ 1 - (documents.embedding <=> query_embedding) as similarity
33
+ from documents
34
+ where metadata @> filter
35
+ order by documents.embedding <=> query_embedding
36
+ limit match_count;
37
+ end;
38
+ $$;