Paperbag committed on
Commit
53e9378
·
1 Parent(s): 60d960d

feat: Upgrade Gemini model, reorder model fallback tiers, enhance error handling, and add image tool forcing with a new test.

Browse files
__pycache__/agent.cpython-312.pyc CHANGED
Binary files a/__pycache__/agent.cpython-312.pyc and b/__pycache__/agent.cpython-312.pyc differ
 
agent.py CHANGED
@@ -55,7 +55,7 @@ model = ChatGroq(
55
  max_retries=2,
56
  )
57
 
58
- # OpenRouter Fallback Model (used when Groq hits rate limits)
59
  openrouter_model = ChatOpenAI(
60
  model="meta-llama/llama-3.3-70b-instruct",
61
  openai_api_key=os.getenv("OPENROUTER_API_KEY"),
@@ -65,7 +65,7 @@ openrouter_model = ChatOpenAI(
65
 
66
  # Google AI Studio Fallback Model (Gemini)
67
  gemini_model = ChatGoogleGenerativeAI(
68
- model="gemini-1.5-pro",
69
  # google_api_key is automatically picked up from GOOGLE_API_KEY environment variable
70
  temperature=0,
71
  )
@@ -80,9 +80,9 @@ def smart_invoke(msgs, use_tools=False):
80
  tertiary = gemini_with_tools if use_tools else gemini_model
81
 
82
  tiers = [
83
- {"name": "Groq", "model": primary, "key": "GROQ_API_KEY"},
84
  {"name": "OpenRouter", "model": secondary, "key": "OPENROUTER_API_KEY"},
85
  {"name": "Gemini", "model": tertiary, "key": "GOOGLE_API_KEY"},
 
86
  ]
87
 
88
  last_exception = None
@@ -94,8 +94,8 @@ def smart_invoke(msgs, use_tools=False):
94
  return tier["model"].invoke(msgs)
95
  except Exception as e:
96
  err_str = str(e).lower()
97
- # Catch rate limits or generic temporary server failures
98
- if any(x in err_str for x in ["rate_limit", "429", "500", "503", "overloaded"]):
99
  print(f"--- {tier['name']} Error: {e}. Falling back... ---")
100
  last_exception = e
101
  continue
@@ -173,8 +173,13 @@ def analyze_image(image_path: str, question: str) -> str:
173
  with open(image_path, "rb") as image_file:
174
  encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
175
 
176
- # Create a separate Vision LLM call specific to the image
177
- vision_model = ChatGroq(model="llama-3.2-90b-vision-preview", temperature=0)
 
 
 
 
 
178
 
179
  message = HumanMessage(
180
  content=[
@@ -411,6 +416,11 @@ def answer_message(state: AgentState) -> AgentState:
411
  """)]
412
  messages = prompt + messages
413
 
 
 
 
 
 
414
  # Multi-step ReAct Loop (Up to 8 reasoning steps)
415
  max_steps = 8
416
  draft_response = None
 
55
  max_retries=2,
56
  )
57
 
58
+ # OpenRouter Model (Primary Fallback)
59
  openrouter_model = ChatOpenAI(
60
  model="meta-llama/llama-3.3-70b-instruct",
61
  openai_api_key=os.getenv("OPENROUTER_API_KEY"),
 
65
 
66
  # Google AI Studio Fallback Model (Gemini)
67
  gemini_model = ChatGoogleGenerativeAI(
68
+ model="gemini-2.5-flash",
69
  # google_api_key is automatically picked up from GOOGLE_API_KEY environment variable
70
  temperature=0,
71
  )
 
80
  tertiary = gemini_with_tools if use_tools else gemini_model
81
 
82
  tiers = [
 
83
  {"name": "OpenRouter", "model": secondary, "key": "OPENROUTER_API_KEY"},
84
  {"name": "Gemini", "model": tertiary, "key": "GOOGLE_API_KEY"},
85
+ {"name": "Groq", "model": primary, "key": "GROQ_API_KEY"},
86
  ]
87
 
88
  last_exception = None
 
94
  return tier["model"].invoke(msgs)
95
  except Exception as e:
96
  err_str = str(e).lower()
97
+ # Catch rate limits, generic temporary server failures, or missing models
98
+ if any(x in err_str for x in ["rate_limit", "429", "500", "503", "overloaded", "not_found", "404"]):
99
  print(f"--- {tier['name']} Error: {e}. Falling back... ---")
100
  last_exception = e
101
  continue
 
173
  with open(image_path, "rb") as image_file:
174
  encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
175
 
176
+ # Use OpenRouter for Vision as a more robust fallback
177
+ vision_model = ChatOpenAI(
178
+ model="google/gemini-2.0-flash-001",
179
+ openai_api_key=os.getenv("OPENROUTER_API_KEY"),
180
+ openai_api_base="https://openrouter.ai/api/v1",
181
+ temperature=0,
182
+ )
183
 
184
  message = HumanMessage(
185
  content=[
 
416
  """)]
417
  messages = prompt + messages
418
 
419
+ # Force tool usage if image path is detected
420
+ for msg in state["messages"]:
421
+ if isinstance(msg, HumanMessage) and "[Attached File Local Path:" in msg.content:
422
+ messages.append(HumanMessage(content="IMPORTANT: I see an image path in the message. I MUST call the analyze_image tool IMMEDIATELY in my next step to see it."))
423
+
424
  # Multi-step ReAct Loop (Up to 8 reasoning steps)
425
  max_steps = 8
426
  draft_response = None
app copy.py CHANGED
@@ -57,7 +57,7 @@ questions_url = f"{DEFAULT_API_URL}/questions"
57
  response = requests.get(questions_url, timeout=15)
58
  response.raise_for_status()
59
  questions_data = response.json()
60
- for item in questions_data[:5]:
61
  question_text = item.get("question")
62
  if question_text is None:
63
  continue
 
57
  response = requests.get(questions_url, timeout=15)
58
  response.raise_for_status()
59
  questions_data = response.json()
60
+ for item in questions_data[3:4]:
61
  question_text = item.get("question")
62
  if question_text is None:
63
  continue
test_image_tool.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import traceback

from agent import build_graph
from langchain_core.messages import HumanMessage, ToolMessage
from dotenv import load_dotenv

load_dotenv()


def test_image_process():
    """Smoke-test: the agent should call the analyze_image tool for an attached image.

    Sends a question containing an "[Attached File Local Path: ...]" marker,
    then walks the resulting message history to report which tools were
    called, what they returned, and whether a FINAL ANSWER was produced.

    Raises:
        Exception: re-raises whatever the graph invocation raised, so the
            script exits non-zero on failure instead of silently "passing".
    """
    graph = build_graph()
    question = (
        "Review the chess position in the image: "
        "[Attached File Local Path: C:\\Users\\Admin\\.cache\\huggingface\\hub"
        "\\datasets--gaia-benchmark--GAIA\\snapshots"
        "\\682dd723ee1e1697e00360edccf2366dc8418dd9\\2023\\validation"
        "\\cca530fc-4052-43b2-b130-b30968d8aa44.png]"
    )

    print(f"--- Testing with question: {question} ---")
    try:
        result = graph.invoke({"messages": [HumanMessage(content=question)]})

        # Log the conversation flow: tool calls, tool outputs, final answer.
        for msg in result["messages"]:
            if hasattr(msg, "tool_calls") and msg.tool_calls:
                # Report every tool call, not just the first one — a run may
                # invoke several tools before answering.
                for call in msg.tool_calls:
                    print(f"Model called tool: {call['name']}")
            elif isinstance(msg, ToolMessage):
                print(f"Tool returned: {msg.content[:100]}...")
            elif hasattr(msg, "content") and msg.content:
                if "FINAL ANSWER" in msg.content:
                    print(f"Final Answer Found: {msg.content}")

    except Exception:
        # Surface the failure instead of swallowing it: a test that only
        # prints the error and returns can never actually fail.
        traceback.print_exc()
        raise


if __name__ == "__main__":
    test_image_process()