yplam commited on
Commit
7c21d30
·
1 Parent(s): ca3ab6d

add more tools

Browse files
Files changed (7) hide show
  1. .env.template +1 -0
  2. agent.py +140 -32
  3. app.py +7 -4
  4. requirements.txt +2 -1
  5. tool/files.py +0 -5
  6. tool/math.py +53 -3
  7. tool/youtube.py +10 -3
.env.template CHANGED
@@ -3,4 +3,5 @@ OPENAI_API_KEY=your_openai_api_key_here
3
  OPENAI_API_BASE=https://api.openai.com/v1
4
  OPENAI_PROXY=http://127.0.0.1:7899
5
  PROXY_URL=http://127.0.0.1:7899
 
6
  # Add other configuration variables below
 
3
  OPENAI_API_BASE=https://api.openai.com/v1
4
  OPENAI_PROXY=http://127.0.0.1:7899
5
  PROXY_URL=http://127.0.0.1:7899
6
+ SERPER_API_KEY=
7
  # Add other configuration variables below
agent.py CHANGED
@@ -1,24 +1,22 @@
 
1
  import os
2
- from typing import Annotated, Optional, TypedDict
3
  from dotenv import load_dotenv
4
  from langgraph.graph.message import add_messages
5
  from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage
6
- from tool.files import read_file
7
- from tool.math import divide
8
  from langchain.chat_models import init_chat_model
9
  from langgraph.graph import StateGraph, MessagesState, START, END
10
  from langgraph.prebuilt import ToolNode
 
 
 
 
11
 
12
- from tool.youtube import get_video_id, youtube_transcript
 
13
 
14
  load_dotenv()
15
 
16
- tools = [
17
- get_video_id,
18
- youtube_transcript,
19
- read_file
20
- ]
21
-
22
  llm = init_chat_model(
23
  model="gpt-4o",
24
  model_provider="openai",
@@ -28,11 +26,122 @@ llm = init_chat_model(
28
  openai_proxy=os.getenv("OPENAI_PROXY"),
29
  )
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  llm_with_tools = llm.bind_tools(tools)
32
 
33
 
34
  class State(TypedDict):
35
- input_file: Optional[str]
 
36
  messages: Annotated[list[AnyMessage], add_messages]
37
  answer: str
38
 
@@ -41,22 +150,19 @@ def should_continue(state: State):
41
  last_message = messages[-1]
42
  if last_message.tool_calls:
43
  return "tools"
44
- return END
45
 
46
- def format_answer(last_message: str):
47
- system_message_content = "You are a general AI assistant. \
48
- Check the user's answer and validate and format it with the following rules: \
49
- The output should be in the following format: \
50
  FINAL ANSWER: [YOUR FINAL ANSWER]. \
51
- YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. \
52
- If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. \
53
- If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. \
54
- If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. \
55
- Your answer should only start with 'FINAL ANSWER: ', then follows with the answer. "
56
  system_message = SystemMessage(content=system_message_content)
57
- messages = [system_message] + [last_message]
58
- answer = llm_with_tools.invoke(messages)
59
- return answer.content
 
60
 
61
  def agent(state: State):
62
  system_message_content = "You are a general AI assistant. I will ask you a question. \
@@ -67,12 +173,11 @@ def agent(state: State):
67
  If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. \
68
  If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. \
69
  Your answer should only start with 'FINAL ANSWER: ', then follows with the answer. "
70
- if state["input_file"]:
71
- system_message_content += f"\nYou are given a file: {state['input_file']}"
72
  system_message = SystemMessage(content=system_message_content)
73
  messages = [system_message] + state["messages"]
74
- answer = llm_with_tools.invoke(messages)
75
- return {"messages": [answer], "answer": format_answer(answer.content)}
76
 
77
 
78
  class Agent:
@@ -81,12 +186,15 @@ class Agent:
81
 
82
  tool_node = ToolNode(tools)
83
  graph_builder = StateGraph(State)
 
84
  graph_builder.add_node("agent", agent)
85
  graph_builder.add_node("tools", tool_node)
86
-
 
87
  graph_builder.add_edge(START, "agent")
88
- graph_builder.add_conditional_edges("agent", should_continue, ["tools", END])
89
  graph_builder.add_edge("tools", "agent")
 
90
  self.graph = graph_builder.compile()
91
  try:
92
  # Save graph visualization as PNG file
@@ -99,6 +207,6 @@ class Agent:
99
  print(f"Could not save graph visualization: {str(e)}")
100
  pass
101
 
102
- def __call__(self, question: str, file_name: str|None) -> str:
103
- result = self.graph.invoke({"input_file": file_name, "messages": [HumanMessage(content=question)]})
104
  return result["answer"]
 
1
+ import json
2
  import os
3
+ from typing import Annotated, Optional, TypedDict, List
4
  from dotenv import load_dotenv
5
  from langgraph.graph.message import add_messages
6
  from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage
 
 
7
  from langchain.chat_models import init_chat_model
8
  from langgraph.graph import StateGraph, MessagesState, START, END
9
  from langgraph.prebuilt import ToolNode
10
+ import requests
11
+ from langchain_community.document_loaders import WikipediaLoader
12
+ from langchain_community.document_loaders import WebBaseLoader
13
+ from langchain_core.tools import tool
14
 
15
+ from tool.math import add, divide, multiply, subtract, modulus
16
+ from tool.youtube import youtube_transcript
17
 
18
  load_dotenv()
19
 
 
 
 
 
 
 
20
  llm = init_chat_model(
21
  model="gpt-4o",
22
  model_provider="openai",
 
26
  openai_proxy=os.getenv("OPENAI_PROXY"),
27
  )
28
 
29
+ @tool
30
+ def analyze_image_by_url(image_url: str, prompt: str) -> str:
31
+ """Using VL model to analyze the image in image_url using the prompt, and return the answer.
32
+ Args:
33
+ image_url: The url of the image to analyze
34
+ prompt: The prompt to use to analyze the image
35
+ Returns:
36
+ The answer to the prompt
37
+ """
38
+ if image_url is None:
39
+ return ""
40
+
41
+ response = llm.invoke([{
42
+ "role": "user",
43
+ "content": [
44
+ {"type": "text", "text": prompt},
45
+ {
46
+ "type": "image_url",
47
+ "image_url": {
48
+ "url": image_url
49
+ }
50
+ }
51
+ ]
52
+ }])
53
+ print(f"Response: {response.content}")
54
+ return response.content
55
+
56
+ def read_file_by_path(file_path: str) -> str:
57
+ """Read the file in file_path and return the content."""
58
+ print(f"Reading file: {file_path}")
59
+ if file_path is None:
60
+ return ""
61
+ with open(file_path, "r") as f:
62
+ return f.read()
63
+
64
+ @tool
65
+ def read_file_by_url(file_url: str) -> str:
66
+ """Read the file in file_url and return the content.
67
+ Args:
68
+ file_url: The url of the file to read
69
+ Returns:
70
+ The raw content of the file
71
+ """
72
+ print(f"Reading file: {file_url}")
73
+ if file_url is None:
74
+ return ""
75
+ response = requests.get(file_url)
76
+ return response.content
77
+
78
+ @tool
79
+ def load_webpage_from_url(url: str) -> str:
80
+ """Load the webpage from the given url and return the content.
81
+ Args:
82
+ url: The url of the webpage to load
83
+ Returns:
84
+ The content of the webpage
85
+ """
86
+ print(f"Loading webpage from: {url}")
87
+ return WebBaseLoader(url).load()
88
+
89
+ @tool
90
+ def load_wikipedia(query: str) -> str:
91
+ """Load Wikipedia for the given query and return the content.
92
+ Args:
93
+ query: The query to search Wikipedia for
94
+ Returns:
95
+ The content of the Wikipedia page
96
+ """
97
+ print(f"Loading Wikipedia for: {query}")
98
+ return WikipediaLoader(query=query, load_max_docs=1).load()
99
+
100
+ @tool
101
+ def search_google(query: str) -> str:
102
+ """Search Google for the given query and return the result.
103
+ Args:
104
+ query: The query to search Google for
105
+ Returns:
106
+ The result of the Google search
107
+ """
108
+ print(f"Searching Google for: {query}")
109
+ url = "https://google.serper.dev/search"
110
+
111
+ payload = json.dumps({
112
+ "q": query
113
+ })
114
+ headers = {
115
+ 'X-API-KEY': os.getenv("SERPER_API_KEY"),
116
+ 'Content-Type': 'application/json'
117
+ }
118
+
119
+ response = requests.request("POST", url, headers=headers, data=payload)
120
+ print(f"Google search result for: {query}")
121
+ print(response.text)
122
+ return response.text
123
+
124
+ tools = [
125
+ youtube_transcript,
126
+ analyze_image_by_url,
127
+ read_file_by_path,
128
+ read_file_by_url,
129
+ load_webpage_from_url,
130
+ load_wikipedia,
131
+ search_google,
132
+ multiply,
133
+ add,
134
+ subtract,
135
+ divide,
136
+ modulus
137
+ ]
138
+
139
  llm_with_tools = llm.bind_tools(tools)
140
 
141
 
142
  class State(TypedDict):
143
+ local_file_path: Optional[str]
144
+ file_url: Optional[str]
145
  messages: Annotated[list[AnyMessage], add_messages]
146
  answer: str
147
 
 
150
  last_message = messages[-1]
151
  if last_message.tool_calls:
152
  return "tools"
153
+ return "format_answer"
154
 
155
+ def format_answer(state: State):
156
+ system_message_content = "You are a AI assistant to extract the answer from the user's answer. \
157
+ The user's answer should be in the following format: \
 
158
  FINAL ANSWER: [YOUR FINAL ANSWER]. \
159
+ Your need to extract and only return the answer. If you don't find the answer, output 'N/A' \
160
+ Remove '.' from the end of the answer."
 
 
 
161
  system_message = SystemMessage(content=system_message_content)
162
+ messages = [system_message] + [state["messages"][-1]]
163
+ answer = llm.invoke(messages)
164
+ return {"answer": answer.content}
165
+
166
 
167
  def agent(state: State):
168
  system_message_content = "You are a general AI assistant. I will ask you a question. \
 
173
  If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. \
174
  If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. \
175
  Your answer should only start with 'FINAL ANSWER: ', then follows with the answer. "
176
+ if state["local_file_path"]:
177
+ system_message_content += f"\nYou can only read files I provide you. You are given a file path related to the question: {state['local_file_path']}, and the online url related to the same file: {state['file_url']}"
178
  system_message = SystemMessage(content=system_message_content)
179
  messages = [system_message] + state["messages"]
180
+ return {"messages": [llm_with_tools.invoke(messages)]}
 
181
 
182
 
183
  class Agent:
 
186
 
187
  tool_node = ToolNode(tools)
188
  graph_builder = StateGraph(State)
189
+
190
  graph_builder.add_node("agent", agent)
191
  graph_builder.add_node("tools", tool_node)
192
+ graph_builder.add_node("format_answer", format_answer)
193
+
194
  graph_builder.add_edge(START, "agent")
195
+ graph_builder.add_conditional_edges("agent", should_continue, ["tools", "format_answer"])
196
  graph_builder.add_edge("tools", "agent")
197
+ graph_builder.add_edge("format_answer", END)
198
  self.graph = graph_builder.compile()
199
  try:
200
  # Save graph visualization as PNG file
 
207
  print(f"Could not save graph visualization: {str(e)}")
208
  pass
209
 
210
+ def __call__(self, question: str, local_file_path: str|None, file_url: str|None) -> str:
211
+ result = self.graph.invoke({"local_file_path": local_file_path, "file_url": file_url, "messages": [HumanMessage(content=question)]})
212
  return result["answer"]
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  from dotenv import load_dotenv
3
  import gradio as gr
4
  import requests
@@ -15,7 +16,6 @@ def download_file(filename: str) -> str:
15
  """
16
  Downloads a file from the API and returns the path to the local file.
17
  """
18
- return None
19
  if filename is None or filename == "":
20
  return None
21
  print(f"Downloading file: {filename}")
@@ -90,15 +90,18 @@ def run_all( username: str|None, submit: bool = True):
90
  for item in questions_data:
91
  task_id = item.get("task_id")
92
  question_text = item.get("question")
93
- file_name = item.get("file_name") or ""
94
- file_path = download_file(file_name)
 
 
 
95
  if not task_id or question_text is None:
96
  print(f"Skipping item with missing task_id or question: {item}")
97
  continue
98
  try:
99
  print("-"*100)
100
  print(f"Running agent on task {task_id}: {question_text}")
101
- submitted_answer = agent(question_text, "")
102
  print("-"*30)
103
  print(f"Submitted answer: {submitted_answer}")
104
  print("-"*100)
 
1
  import os
2
+ import tempfile
3
  from dotenv import load_dotenv
4
  import gradio as gr
5
  import requests
 
16
  """
17
  Downloads a file from the API and returns the path to the local file.
18
  """
 
19
  if filename is None or filename == "":
20
  return None
21
  print(f"Downloading file: {filename}")
 
90
  for item in questions_data:
91
  task_id = item.get("task_id")
92
  question_text = item.get("question")
93
+ local_file_path = None
94
+ file_url = None
95
+ if item.get("file_name"):
96
+ local_file_path = download_file(task_id)
97
+ file_url = f"{DEFAULT_API_URL}/files/{task_id}"
98
  if not task_id or question_text is None:
99
  print(f"Skipping item with missing task_id or question: {item}")
100
  continue
101
  try:
102
  print("-"*100)
103
  print(f"Running agent on task {task_id}: {question_text}")
104
+ submitted_answer = agent(question_text, local_file_path, file_url)
105
  print("-"*30)
106
  print(f"Submitted answer: {submitted_answer}")
107
  print("-"*100)
requirements.txt CHANGED
@@ -5,4 +5,5 @@ langchain_openai
5
  langchain
6
  python-dotenv
7
  youtube_transcript_api
8
- pandas
 
 
5
  langchain
6
  python-dotenv
7
  youtube_transcript_api
8
+ pandas
9
+ langchain_community
tool/files.py DELETED
@@ -1,5 +0,0 @@
1
- def read_file(file_path: str) -> str:
2
- """Reads the content of a file and returns it as a string."""
3
- print(f"Reading file: {file_path}")
4
- with open(file_path, 'r') as file:
5
- return file.read()
 
 
 
 
 
 
tool/math.py CHANGED
@@ -1,3 +1,53 @@
1
- def divide(a: int, b: int) -> float:
2
- """Divide a and b - for Master Wayne's occasional calculations."""
3
- return a / b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from langchain_core.tools import tool
3
+
4
+ @tool
5
+ def multiply(a: int, b: int) -> int:
6
+ """Multiply two numbers.
7
+ Args:
8
+ a: first int
9
+ b: second int
10
+ """
11
+ return a * b
12
+
13
+ @tool
14
+ def add(a: int, b: int) -> int:
15
+ """Add two numbers.
16
+
17
+ Args:
18
+ a: first int
19
+ b: second int
20
+ """
21
+ return a + b
22
+
23
+ @tool
24
+ def subtract(a: int, b: int) -> int:
25
+ """Subtract two numbers.
26
+
27
+ Args:
28
+ a: first int
29
+ b: second int
30
+ """
31
+ return a - b
32
+
33
+ @tool
34
+ def divide(a: int, b: int) -> int:
35
+ """Divide two numbers.
36
+
37
+ Args:
38
+ a: first int
39
+ b: second int
40
+ """
41
+ if b == 0:
42
+ raise ValueError("Cannot divide by zero.")
43
+ return a / b
44
+
45
+ @tool
46
+ def modulus(a: int, b: int) -> int:
47
+ """Get the modulus of two numbers.
48
+
49
+ Args:
50
+ a: first int
51
+ b: second int
52
+ """
53
+ return a % b
tool/youtube.py CHANGED
@@ -1,12 +1,18 @@
1
  import os
2
  from youtube_transcript_api import YouTubeTranscriptApi
3
  from youtube_transcript_api.proxies import GenericProxyConfig
 
4
 
5
- def youtube_transcript(video_id: str) -> str:
 
6
  """
7
- Extracts the transcript from a YouTube video id
 
 
 
 
8
  """
9
- print(f"Extracting transcript from: {video_id}")
10
  try:
11
  ytt_api = YouTubeTranscriptApi()
12
  if os.getenv("PROXY_URL"):
@@ -16,6 +22,7 @@ def youtube_transcript(video_id: str) -> str:
16
  https_url=os.getenv("PROXY_URL"),
17
  )
18
  )
 
19
  transcript = ytt_api.fetch(video_id)
20
  print(f"Transcript: {transcript}")
21
  return transcript
 
1
  import os
2
  from youtube_transcript_api import YouTubeTranscriptApi
3
  from youtube_transcript_api.proxies import GenericProxyConfig
4
+ from langchain_core.tools import tool
5
 
6
+ @tool
7
+ def youtube_transcript(video_url: str) -> str:
8
  """
9
+ Extracts the transcript from a YouTube video url
10
+ Args:
11
+ video_url: The url of the YouTube video
12
+ Returns:
13
+ The transcript of the YouTube video
14
  """
15
+ print(f"Extracting transcript from: {video_url}")
16
  try:
17
  ytt_api = YouTubeTranscriptApi()
18
  if os.getenv("PROXY_URL"):
 
22
  https_url=os.getenv("PROXY_URL"),
23
  )
24
  )
25
+ video_id = get_video_id(video_url)
26
  transcript = ytt_api.fetch(video_id)
27
  print(f"Transcript: {transcript}")
28
  return transcript