sampsong commited on
Commit
6c0615e
·
1 Parent(s): d85f139

fix web search, pass image url to llm

Browse files
Files changed (3) hide show
  1. Agents/agent.py +5 -2
  2. Tools/tools.py +26 -5
  3. app.py +9 -3
Agents/agent.py CHANGED
@@ -19,7 +19,7 @@ from langchain_core.messages import (
19
  convert_to_messages,
20
  )
21
  from pydantic import BaseModel
22
- from Tools.tools import webSearch, arxivSearch, wikiSearch,add,multiply,divide,substract, modulus
23
  from langchain_core.messages import SystemMessage, HumanMessage
24
  from dotenv import load_dotenv
25
  from supabase.client import Client, create_client
@@ -120,7 +120,10 @@ tools = [
120
  add,
121
  substract,
122
  divide,
123
- modulus
 
 
 
124
  ]
125
 
126
  def tools_condition1(
 
19
  convert_to_messages,
20
  )
21
  from pydantic import BaseModel
22
+ from Tools.tools import webSearch, youtubeVideoTranscript, arxivSearch, wikiSearch,add,multiply,divide,substract,modulus,power,count_substring
23
  from langchain_core.messages import SystemMessage, HumanMessage
24
  from dotenv import load_dotenv
25
  from supabase.client import Client, create_client
 
120
  add,
121
  substract,
122
  divide,
123
+ modulus,
124
+ power,
125
+ count_substring,
126
+ youtubeVideoTranscript
127
  ]
128
 
129
  def tools_condition1(
Tools/tools.py CHANGED
@@ -93,16 +93,17 @@ def arxivSearch(searchQuery:str) -> str:
93
  ])
94
  return {"arxiv_result": formatted_results}
95
 
 
96
  @tool
97
  def webSearch(searchQuery:str) -> str:
98
  """
99
- search the web using Tavily to get three matching results
100
 
101
  args:
102
  searchQuery: search query
103
  """
104
- print("web_search")
105
- search_results = TavilySearchResults(max_results=3).invoke(query=searchQuery)
106
  formatted_results = "\n\n--\n\n".join(
107
  [
108
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page","")}"/>\n{doc.page_content}\n</Document>'
@@ -112,9 +113,9 @@ def webSearch(searchQuery:str) -> str:
112
  return {"web_search": formatted_results}
113
 
114
  @tool
115
- def youtubeTranscript(youtubeURL:str) -> str:
116
  """
117
- obtain youtube transcript by passing in the youtube url
118
 
119
  args:
120
  youtubeURL: youtube url to pull out the transcript
@@ -129,5 +130,25 @@ def youtubeTranscript(youtubeURL:str) -> str:
129
  formatted_results = "\n\n".join(map(repr, loader.load()))
130
  return {"Youtube transcript":formatted_results}
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
 
 
93
  ])
94
  return {"arxiv_result": formatted_results}
95
 
96
+
97
  @tool
98
  def webSearch(searchQuery:str) -> str:
99
  """
100
+ search the web using Tavily to get 2 matching results
101
 
102
  args:
103
  searchQuery: search query
104
  """
105
+ print("web_search: {searchQuery}")
106
+ search_results = TavilySearchResults(max_results=2).invoke(input=searchQuery)
107
  formatted_results = "\n\n--\n\n".join(
108
  [
109
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page","")}"/>\n{doc.page_content}\n</Document>'
 
113
  return {"web_search": formatted_results}
114
 
115
  @tool
116
+ def youtubeVideoTranscript(youtubeURL:str) -> str:
117
  """
118
+ Get youtube video transcript by passing in the youtube url
119
 
120
  args:
121
  youtubeURL: youtube url to pull out the transcript
 
130
  formatted_results = "\n\n".join(map(repr, loader.load()))
131
  return {"Youtube transcript":formatted_results}
132
 
133
+ @tool
134
+ def power(a: float, b: float) -> float:
135
+ """
136
+ Get the power of two numbers.
137
+ Args:
138
+ a (float): the first number
139
+ b (float): the second number
140
+ """
141
+ return a**b
142
+
143
+ @tool
144
+ def count_substring(substring:str, text:str) -> int:
145
+ """
146
+ Get the number of occurences of a substring within some text. Useful for 'How many (substring) are in (text)?'
147
+ Args:
148
+ substring (str): the substring to check for.
149
+ text (str): the text to search through.
150
+ """
151
+ return text.count(substring)
152
+
153
 
154
 
app.py CHANGED
@@ -16,6 +16,7 @@ langfuse_handler = CallbackHandler()
16
  testMode = bool(os.getenv("TestMode"))
17
  langFuseOn = bool(os.getenv("LangFuseOn"))
18
  agentType = os.getenv("AgentType")
 
19
 
20
  # --- Basic Agent Definition ---
21
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
@@ -39,9 +40,13 @@ class BasicAgent:
39
  def __init__(self):
40
  print("BasicAgent initialized.")
41
  self.graph = build_graph()
42
- def __call__(self, question: str) -> str:
43
  print(f"Agent received question (first 50 chars): {question[:50]}...")
44
- messages = [HumanMessage(content=question)]
 
 
 
 
45
 
46
  if(not langFuseOn):
47
  print("no langfuse")
@@ -140,13 +145,14 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
140
  print(f"Running agent on {len(questions_data)} questions...")
141
  for item in questions_data:
142
  task_id = item.get("task_id")
 
143
  question_text = item.get("question")
144
  print(f"running on Question data {question_text}")
145
  if not task_id or question_text is None:
146
  print(f"Skipping item with missing task_id or question: {item}")
147
  continue
148
  try:
149
- submitted_answer = agent(question_text)
150
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
151
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
152
  except Exception as e:
 
16
  testMode = bool(os.getenv("TestMode"))
17
  langFuseOn = bool(os.getenv("LangFuseOn"))
18
  agentType = os.getenv("AgentType")
19
+ gaiaValidationURL = os.getenv("GaiaValidationURL")
20
 
21
  # --- Basic Agent Definition ---
22
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
 
40
  def __init__(self):
41
  print("BasicAgent initialized.")
42
  self.graph = build_graph()
43
+ def __call__(self, question: str, imageURL: str="") -> str:
44
  print(f"Agent received question (first 50 chars): {question[:50]}...")
45
+ if(imageURL.strip == ""):
46
+ messages = [HumanMessage(content=question)]
47
+ else:
48
+ formattedImageURL = gaiaValidationURL + imageURL
49
+ messages = [HumanMessage(content=question, additional_kwargs={imageURL:formattedImageURL })]
50
 
51
  if(not langFuseOn):
52
  print("no langfuse")
 
145
  print(f"Running agent on {len(questions_data)} questions...")
146
  for item in questions_data:
147
  task_id = item.get("task_id")
148
+ file_name = item.get("file_name")
149
  question_text = item.get("question")
150
  print(f"running on Question data {question_text}")
151
  if not task_id or question_text is None:
152
  print(f"Skipping item with missing task_id or question: {item}")
153
  continue
154
  try:
155
+ submitted_answer = agent(question_text,file_name)
156
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
157
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
158
  except Exception as e: