Paperbag committed on
Commit
5b69a26
·
1 Parent(s): 81dfa52

increase models

Browse files
Files changed (2) hide show
  1. __pycache__/agent.cpython-312.pyc +0 -0
  2. agent.py +45 -19
__pycache__/agent.cpython-312.pyc CHANGED
Binary files a/__pycache__/agent.cpython-312.pyc and b/__pycache__/agent.cpython-312.pyc differ
 
agent.py CHANGED
@@ -70,41 +70,66 @@ gemini_model = ChatGoogleGenerativeAI(
70
  temperature=0,
71
  )
72
 
73
- def smart_invoke(msgs, use_tools=False):
74
  """
75
- Tiered fallback: Groq -> OpenRouter -> Google AI Studio.
76
- Retries next tier if a 429 (rate limit) or server-side error occurs.
77
  """
78
  primary = model_with_tools if use_tools else model
79
  secondary = openrouter_with_tools if use_tools else openrouter_model
80
  tertiary = gemini_with_tools if use_tools else gemini_model
81
 
 
 
 
82
  tiers = [
83
  {"name": "OpenRouter", "model": secondary, "key": "OPENROUTER_API_KEY"},
84
- {"name": "Gemini", "model": tertiary, "key": "GOOGLE_API_KEY"},
85
  {"name": "Groq", "model": primary, "key": "GROQ_API_KEY"},
86
  ]
87
 
88
  last_exception = None
89
- for tier in tiers:
 
90
  if not os.getenv(tier["key"]):
91
- continue # Skip if no API key
92
 
93
- try:
94
- return tier["model"].invoke(msgs)
95
- except Exception as e:
96
- err_str = str(e).lower()
97
- # Catch rate limits, generic temporary server failures, or missing models
98
- if any(x in err_str for x in ["rate_limit", "429", "500", "503", "overloaded", "not_found", "404"]):
99
- print(f"--- {tier['name']} Error: {e}. Falling back... ---")
100
- last_exception = e
101
- continue
102
- raise e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  if last_exception:
105
  print("CRITICAL: All fallback tiers failed.")
106
  raise last_exception
107
- return None
108
 
109
  @tool
110
  def web_search(keywords: str) -> str:
@@ -424,6 +449,7 @@ def answer_message(state: AgentState) -> AgentState:
424
  # Multi-step ReAct Loop (Up to 8 reasoning steps)
425
  max_steps = 8
426
  draft_response = None
 
427
 
428
  for step in range(max_steps):
429
  if step > 0:
@@ -431,7 +457,7 @@ def answer_message(state: AgentState) -> AgentState:
431
  time.sleep(4)
432
 
433
  print(f"--- ReAct Step {step + 1} ---")
434
- ai_msg = smart_invoke(messages, use_tools=True)
435
  messages.append(ai_msg)
436
 
437
  # Check if the model requested tools
@@ -475,7 +501,7 @@ def answer_message(state: AgentState) -> AgentState:
475
  "If it is a name or word, just return the exact string. If a list, return only the comma-separated list."
476
  )
477
  )
478
- final_response = smart_invoke([formatting_sys, HumanMessage(content=draft_response.content)], use_tools=False)
479
  print(f"Draft response: {draft_response.content}")
480
  print(f"Strict Final response: {final_response.content}")
481
 
 
70
  temperature=0,
71
  )
72
 
73
+ def smart_invoke(msgs, use_tools=False, start_tier=0):
74
  """
75
+ Tiered fallback: OpenRouter -> Gemini -> Groq.
76
- Retries next tier if a 429 (rate limit), 402 (credits), or 404 (model not found) error occurs.
77
  """
78
  primary = model_with_tools if use_tools else model
79
  secondary = openrouter_with_tools if use_tools else openrouter_model
80
  tertiary = gemini_with_tools if use_tools else gemini_model
81
 
82
+ # Alternative Gemini model names to try if the default model returns 404
83
+ gemini_alternatives = ["gemini-2.5-flash", "gemini-2.5-flash-lite", "gemini-3.1-flash-lite", "gemini-3-flash"]
84
+
85
  tiers = [
86
  {"name": "OpenRouter", "model": secondary, "key": "OPENROUTER_API_KEY"},
87
+ {"name": "Gemini", "model": tertiary, "key": "GOOGLE_API_KEY", "alternatives": gemini_alternatives},
88
  {"name": "Groq", "model": primary, "key": "GROQ_API_KEY"},
89
  ]
90
 
91
  last_exception = None
92
+ for i in range(start_tier, len(tiers)):
93
+ tier = tiers[i]
94
  if not os.getenv(tier["key"]):
95
+ continue
96
 
97
+ # For tiers with alternatives (like Gemini), try each if 404 occurs
98
+ models_to_try = [tier["model"]]
99
+ if "alternatives" in tier:
100
+ for alt_name in tier["alternatives"]:
101
+ # Create a new model instance if the default one fails
102
+ alt_model = ChatGoogleGenerativeAI(model=alt_name, temperature=0).bind_tools(tools) if use_tools else ChatGoogleGenerativeAI(model=alt_name, temperature=0)
103
+ models_to_try.append(alt_model)
104
+
105
+ for current_model in models_to_try:
106
+ try:
107
+ model_name = getattr(current_model, "model", tier["name"])
108
+ print(f"--- Calling {tier['name']} ({model_name}) ---")
109
+ return current_model.invoke(msgs), i
110
+ except Exception as e:
111
+ err_str = str(e).lower()
112
+ # If it's a 404 (not found) and we have more alternatives, continue to the next alternative
113
+ if any(x in err_str for x in ["not_found", "404"]) and current_model != models_to_try[-1]:
114
+ print(f"--- {tier['name']} model {model_name} not found. Trying alternative... ---")
115
+ continue
116
+
117
+ # Catch other fallback triggers
118
+ if any(x in err_str for x in ["rate_limit", "429", "500", "503", "overloaded", "not_found", "404", "402", "credits"]):
119
+ print(f"--- {tier['name']} Error: {e}. Falling back... ---")
120
+ last_exception = e
121
+ break # Move to next tier
122
+ raise e
123
+
124
+ if last_exception:
125
+ print("CRITICAL: All fallback tiers failed.")
126
+ raise last_exception
127
+ return None, 0
128
 
129
  if last_exception:
130
  print("CRITICAL: All fallback tiers failed.")
131
  raise last_exception
132
+ return None, 0
133
 
134
  @tool
135
  def web_search(keywords: str) -> str:
 
449
  # Multi-step ReAct Loop (Up to 8 reasoning steps)
450
  max_steps = 8
451
  draft_response = None
452
+ current_tier = 0
453
 
454
  for step in range(max_steps):
455
  if step > 0:
 
457
  time.sleep(4)
458
 
459
  print(f"--- ReAct Step {step + 1} ---")
460
+ ai_msg, current_tier = smart_invoke(messages, use_tools=True, start_tier=current_tier)
461
  messages.append(ai_msg)
462
 
463
  # Check if the model requested tools
 
501
  "If it is a name or word, just return the exact string. If a list, return only the comma-separated list."
502
  )
503
  )
504
+ final_response, _ = smart_invoke([formatting_sys, HumanMessage(content=draft_response.content)], use_tools=False, start_tier=current_tier)
505
  print(f"Draft response: {draft_response.content}")
506
  print(f"Strict Final response: {final_response.content}")
507