Paperbag committed on
Commit
39b1e37
·
1 Parent(s): 45386f2

feat: Update `analyze_image` and `analyze_video` tool descriptions and system prompt rules to enhance multimedia processing.

Browse files
Files changed (4) hide show
  1. __pycache__/agent.cpython-39.pyc +0 -0
  2. agent.py +79 -15
  3. app.py +2 -1
  4. requirements.txt +2 -0
__pycache__/agent.cpython-39.pyc CHANGED
Binary files a/__pycache__/agent.cpython-39.pyc and b/__pycache__/agent.cpython-39.pyc differ
 
agent.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import datetime
3
  import subprocess
4
  import tempfile
 
5
  from typing import TypedDict, List, Dict, Any, Optional, Union
6
  from langchain_core import tools
7
  from langgraph.graph import StateGraph, START, END
@@ -15,6 +16,8 @@ from groq import Groq
15
  from langchain_groq import ChatGroq
16
  from langchain_community.document_loaders.image import UnstructuredImageLoader
17
  from langchain_community.document_loaders import WebBaseLoader
 
 
18
  import base64
19
 
20
  try:
@@ -22,6 +25,8 @@ try:
22
  except ImportError:
23
  cv2 = None
24
 
 
 
25
  whisper_model = None
26
  def get_whisper():
27
  global whisper_model
@@ -48,9 +53,59 @@ model = ChatGroq(
48
  max_tokens=None,
49
  timeout=None,
50
  max_retries=2,
51
- # other params...
52
  )
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  @tool
55
  def web_search(keywords: str) -> str:
56
  """
@@ -105,12 +160,13 @@ def wiki_search(query: str) -> str:
105
  @tool
106
  def analyze_image(image_path: str, question: str) -> str:
107
  """
108
- Analyzes an image to answer a specific question.
109
- Use this tool when you need to extract visual information from an image file.
 
110
 
111
  Args:
112
  image_path: The local path or URL to the image file.
113
- question: The specific question to ask about the image.
114
  """
115
  try:
116
  # If it's a local file, we encode it to base64
@@ -154,12 +210,13 @@ def analyze_audio(audio_path: str, question: str) -> str:
154
  @tool
155
  def analyze_video(video_path: str, question: str) -> str:
156
  """
157
- Analyzes a video file to answer questions about its content.
158
- Extracts key frames and describes what is happening.
 
159
 
160
  Args:
161
  video_path: The local path to the video file.
162
- question: The specific question to ask about the video.
163
  """
164
  if cv2 is None:
165
  return "Error: cv2 is not installed. Please install opencv-python."
@@ -321,6 +378,8 @@ def restart_required(state: AgentState) -> AgentState:
321
  tools = [web_search, wiki_search, analyze_image, analyze_audio, analyze_video, read_url, run_python_script, read_document]
322
  tools_by_name = {tool.name: tool for tool in tools}
323
  model_with_tools = model.bind_tools(tools)
 
 
324
 
325
  def answer_message(state: AgentState) -> AgentState:
326
  messages = state["messages"]
@@ -334,11 +393,12 @@ def answer_message(state: AgentState) -> AgentState:
334
  TODAY'S EXACT DATE is {current_date}. Keep this in mind for all time-sensitive queries.
335
 
336
  CRITICAL RULES FOR SEARCH & TOOLS:
337
- 1. If a file is attached, use the appropriate tool (run_python_script, read_document, analyze_image, analyze_audio, analyze_video) to answer the question based on the file content.
338
- 2. Use run_python_script freely to process data (pandas), read complex documents (.xlsx, .pdf), or do heavy math calculations.
339
- 3. When using tools like web_search or wiki_search, do not blindly search the entire question. Extract the core entities.
340
- 4. If the first search result doesn't contain the answer, THINK step-by-step, refine your search query (e.g., use synonyms, or search for broader concepts), and search again.
341
- 5. Cross-reference facts if they seem ambiguous.
 
342
 
343
  Do not include any thought process before answering the question, and only response exactly what was being asked of you.
344
  If you are not able to provide an answer, use tools or state the limitation that you're facing instead.
@@ -358,8 +418,12 @@ def answer_message(state: AgentState) -> AgentState:
358
  draft_response = None
359
 
360
  for step in range(max_steps):
 
 
 
 
361
  print(f"--- ReAct Step {step + 1} ---")
362
- ai_msg = model_with_tools.invoke(messages)
363
  messages.append(ai_msg)
364
 
365
  # Check if the model requested tools
@@ -390,7 +454,7 @@ def answer_message(state: AgentState) -> AgentState:
390
  print("Max reasoning steps reached. Forcing answer extraction.")
391
  forced_msg = HumanMessage(content="You have reached the maximum reasoning steps. Please provide your best final answer based on the current context without any more tool calls.")
392
  messages.append(forced_msg)
393
- draft_response = model.invoke(messages)
394
 
395
  # Third pass: strict GAIA formatting extraction
396
  formatting_sys = SystemMessage(
@@ -403,7 +467,7 @@ def answer_message(state: AgentState) -> AgentState:
403
  "If it is a name or word, just return the exact string. If a list, return only the comma-separated list."
404
  )
405
  )
406
- final_response = model.invoke([formatting_sys, HumanMessage(content=draft_response.content)])
407
  print(f"Draft response: {draft_response.content}")
408
  print(f"Strict Final response: {final_response.content}")
409
 
 
2
  import datetime
3
  import subprocess
4
  import tempfile
5
+ import time
6
  from typing import TypedDict, List, Dict, Any, Optional, Union
7
  from langchain_core import tools
8
  from langgraph.graph import StateGraph, START, END
 
16
  from langchain_groq import ChatGroq
17
  from langchain_community.document_loaders.image import UnstructuredImageLoader
18
  from langchain_community.document_loaders import WebBaseLoader
19
+ from langchain_openai import ChatOpenAI
20
+ from langchain_google_genai import ChatGoogleGenerativeAI
21
  import base64
22
 
23
  try:
 
25
  except ImportError:
26
  cv2 = None
27
 
28
+ # os.environ["USER_AGENT"] = "gaia-agent/1.0"
29
+
30
  whisper_model = None
31
  def get_whisper():
32
  global whisper_model
 
53
  max_tokens=None,
54
  timeout=None,
55
  max_retries=2,
 
56
  )
57
 
58
+ # OpenRouter Fallback Model (used when Groq hits rate limits)
59
+ openrouter_model = ChatOpenAI(
60
+ model="meta-llama/llama-3.3-70b-instruct",
61
+ openai_api_key=os.getenv("OPENROUTER_API_KEY"),
62
+ openai_api_base="https://openrouter.ai/api/v1",
63
+ temperature=0,
64
+ )
65
+
66
+ # Google AI Studio Fallback Model (Gemini)
67
+ gemini_model = ChatGoogleGenerativeAI(
68
+ model="gemini-1.5-pro",
69
+ google_api_key=os.getenv("GOOGLE_API_KEY"),
70
+ temperature=0,
71
+ )
72
+
73
def smart_invoke(msgs, use_tools=False):
    """
    Invoke an LLM with tiered fallback: Groq -> OpenRouter -> Google AI Studio.

    Tries each tier whose API key environment variable is set. On a rate-limit
    (429) or transient server error (500/503/"overloaded") it logs and falls
    through to the next tier; any other exception is re-raised immediately.

    Args:
        msgs: The message list passed to the model's ``invoke``.
        use_tools: When True, use the tool-bound variant of each model.

    Returns:
        The AI message from the first tier that succeeds.

    Raises:
        RuntimeError: If no tier has an API key configured.
        Exception: The last transient error if every configured tier failed,
            or the first non-transient error encountered.
    """
    primary = model_with_tools if use_tools else model
    secondary = openrouter_with_tools if use_tools else openrouter_model
    tertiary = gemini_with_tools if use_tools else gemini_model

    tiers = [
        {"name": "Groq", "model": primary, "key": "GROQ_API_KEY"},
        {"name": "OpenRouter", "model": secondary, "key": "OPENROUTER_API_KEY"},
        {"name": "Gemini", "model": tertiary, "key": "GOOGLE_API_KEY"},
    ]

    # Substrings that mark an error as retryable on the next tier.
    transient_markers = ("rate_limit", "429", "500", "503", "overloaded")

    last_exception = None
    for tier in tiers:
        if not os.getenv(tier["key"]):
            continue  # Skip tiers without credentials configured.

        try:
            return tier["model"].invoke(msgs)
        except Exception as e:
            err_str = str(e).lower()
            # Catch rate limits or generic temporary server failures.
            if any(marker in err_str for marker in transient_markers):
                print(f"--- {tier['name']} Error: {e}. Falling back... ---")
                last_exception = e
                continue
            raise  # Non-transient error: surface immediately with full traceback.

    if last_exception:
        print("CRITICAL: All fallback tiers failed.")
        raise last_exception
    # Previously returned None here, which made callers crash later with a
    # confusing AttributeError (e.g. on `.content` / `.tool_calls`); fail
    # fast with an actionable message instead.
    raise RuntimeError(
        "No LLM tier is configured: set GROQ_API_KEY, OPENROUTER_API_KEY, "
        "or GOOGLE_API_KEY."
    )
109
  @tool
110
  def web_search(keywords: str) -> str:
111
  """
 
160
  @tool
161
  def analyze_image(image_path: str, question: str) -> str:
162
  """
163
+ EXTERNAL SIGHT API: Sends an image path to a Vision Model to answer a specific question.
164
+ YOU MUST CALL THIS TOOL ANY TIME an image (.png, .jpg, .jpeg) is attached to the prompt.
165
+ NEVER claim you cannot see images. Use this tool instead.
166
 
167
  Args:
168
  image_path: The local path or URL to the image file.
169
+ question: Specific question describing what you want the vision model to look for.
170
  """
171
  try:
172
  # If it's a local file, we encode it to base64
 
210
  @tool
211
  def analyze_video(video_path: str, question: str) -> str:
212
  """
213
+ EXTERNAL SIGHT/HEARING API: Sends a video file to an external Vision/Audio model.
214
+ YOU MUST CALL THIS TOOL ANY TIME a video (.mp4, .avi) is attached to the prompt.
215
+ NEVER claim you cannot analyze videos. Use this tool instead.
216
 
217
  Args:
218
  video_path: The local path to the video file.
219
+ question: Specific question describing what you want to extract from the video.
220
  """
221
  if cv2 is None:
222
  return "Error: cv2 is not installed. Please install opencv-python."
 
378
  tools = [web_search, wiki_search, analyze_image, analyze_audio, analyze_video, read_url, run_python_script, read_document]
379
  tools_by_name = {tool.name: tool for tool in tools}
380
  model_with_tools = model.bind_tools(tools)
381
+ openrouter_with_tools = openrouter_model.bind_tools(tools)
382
+ gemini_with_tools = gemini_model.bind_tools(tools)
383
 
384
  def answer_message(state: AgentState) -> AgentState:
385
  messages = state["messages"]
 
393
  TODAY'S EXACT DATE is {current_date}. Keep this in mind for all time-sensitive queries.
394
 
395
  CRITICAL RULES FOR SEARCH & TOOLS:
396
+ 1. If an image, video, or audio file is attached, YOU MUST NOT SAY "I don't have access to analyze..." or "I cannot see". YOU ARE NOT BLIND. You have external APIs (analyze_image, analyze_video, analyze_audio) that will act as your eyes and ears! ALWAYS invoke these tools immediately to get descriptions!
397
+ 2. If a text/data file is attached, use the appropriate tool (run_python_script, read_document) to analyze the file content.
398
+ 3. Use run_python_script freely to process data (pandas), read complex documents (.xlsx, .pdf), or do heavy math calculations.
399
+ 4. When using tools like web_search or wiki_search, do not blindly search the entire question. Extract the core entities.
400
+ 5. If the first search result doesn't contain the answer, THINK step-by-step, refine your search query (e.g., use synonyms, or search for broader concepts), and search again.
401
+ 6. Cross-reference facts if they seem ambiguous.
402
 
403
  Do not include any thought process before answering the question, and only response exactly what was being asked of you.
404
  If you are not able to provide an answer, use tools or state the limitation that you're facing instead.
 
418
  draft_response = None
419
 
420
  for step in range(max_steps):
421
+ if step > 0:
422
+ # Prevents Groq API Request/Tokens Per Minute exceptions when deep reasoning
423
+ time.sleep(4)
424
+
425
  print(f"--- ReAct Step {step + 1} ---")
426
+ ai_msg = smart_invoke(messages, use_tools=True)
427
  messages.append(ai_msg)
428
 
429
  # Check if the model requested tools
 
454
  print("Max reasoning steps reached. Forcing answer extraction.")
455
  forced_msg = HumanMessage(content="You have reached the maximum reasoning steps. Please provide your best final answer based on the current context without any more tool calls.")
456
  messages.append(forced_msg)
457
+ draft_response = smart_invoke(messages, use_tools=False)
458
 
459
  # Third pass: strict GAIA formatting extraction
460
  formatting_sys = SystemMessage(
 
467
  "If it is a name or word, just return the exact string. If a list, return only the comma-separated list."
468
  )
469
  )
470
+ final_response = smart_invoke([formatting_sys, HumanMessage(content=draft_response.content)], use_tools=False)
471
  print(f"Draft response: {draft_response.content}")
472
  print(f"Strict Final response: {final_response.content}")
473
 
app.py CHANGED
@@ -53,7 +53,8 @@ def file_extract(local_file_path, task_id):
53
  logger.warning(f"Could not download file '{local_file_path}' for task_id {task_id}. Make sure you accepted GAIA terms on HF and set HF_TOKEN.")
54
  return None
55
 
56
- def run_and_submit_all( profile: gr.OAuthProfile | None):
 
57
  """
58
  Fetches all questions, runs the BasicAgent on them, submits all answers,
59
  and displays the results.
 
53
  logger.warning(f"Could not download file '{local_file_path}' for task_id {task_id}. Make sure you accepted GAIA terms on HF and set HF_TOKEN.")
54
  return None
55
 
56
+ from typing import Optional
57
+ def run_and_submit_all(profile: Optional[gr.OAuthProfile] = None):
58
  """
59
  Fetches all questions, runs the BasicAgent on them, submits all answers,
60
  and displays the results.
requirements.txt CHANGED
@@ -25,3 +25,5 @@ opencv-python
25
  beautifulsoup4
26
  PyPDF2
27
  openai-whisper
 
 
 
25
  beautifulsoup4
26
  PyPDF2
27
  openai-whisper
28
+ langchain-openai
29
+ langchain-google-genai