Paperbag committed on
Commit
3fc1414
·
1 Parent(s): 825865b

fix nvidia

Browse files
Files changed (3) hide show
  1. __pycache__/agent.cpython-312.pyc +0 -0
  2. agent.py +5 -2
  3. verify_fixes.py +72 -0
__pycache__/agent.cpython-312.pyc CHANGED
Binary files a/__pycache__/agent.cpython-312.pyc and b/__pycache__/agent.cpython-312.pyc differ
 
agent.py CHANGED
@@ -73,7 +73,7 @@ gemini_model = ChatGoogleGenerativeAI(
73
 
74
  # NVIDIA Model (Secondary Fallback)
75
  nvidia_model = ChatOpenAI(
76
- model="nvidia/llama-3.1-405b-instruct",
77
  openai_api_key=os.getenv("NVIDIA_API_KEY"),
78
  openai_api_base="https://integrate.api.nvidia.com/v1",
79
  temperature=0,
@@ -138,8 +138,11 @@ def smart_invoke(msgs, use_tools=False, start_tier=0):
138
 
139
  # Catch other fallback triggers
140
  if any(x in err_str for x in ["rate_limit", "429", "500", "503", "overloaded", "not_found", "404", "402", "credits"]):
141
- print(f"--- {tier['name']} Error: {e}. Falling back... ---")
142
  last_exception = e
 
 
 
143
  break # Move to next tier
144
  raise e
145
 
 
73
 
74
  # NVIDIA Model (Secondary Fallback)
75
  nvidia_model = ChatOpenAI(
76
+ model="meta/llama-3.1-405b-instruct",
77
  openai_api_key=os.getenv("NVIDIA_API_KEY"),
78
  openai_api_base="https://integrate.api.nvidia.com/v1",
79
  temperature=0,
 
138
 
139
  # Catch other fallback triggers
140
  if any(x in err_str for x in ["rate_limit", "429", "500", "503", "overloaded", "not_found", "404", "402", "credits"]):
141
+ print(f"--- {tier['name']} Error: {e}. Trying next model/tier... ---")
142
  last_exception = e
143
+ # If this tier has more alternatives, continue to the next one
144
+ if current_model != models_to_try[-1]:
145
+ continue
146
  break # Move to next tier
147
  raise e
148
 
verify_fixes.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from unittest.mock import MagicMock, patch
4
+
5
+ # Mocking modules
6
+ sys.modules['cv2'] = MagicMock()
7
+ sys.modules['whisper'] = MagicMock()
8
+
9
+ # Set dummy env vars
10
+ os.environ["OPENROUTER_API_KEY"] = "dummy"
11
+ os.environ["GOOGLE_API_KEY"] = "dummy"
12
+ os.environ["GROQ_API_KEY"] = "dummy"
13
+ os.environ["NVIDIA_API_KEY"] = "dummy"
14
+ os.environ["VERCEL_API_KEY"] = "dummy"
15
+
16
+ sys.path.append(os.getcwd())
17
+
18
+ import agent
19
+ from langchain_core.messages import HumanMessage
20
+
21
+ def test_gemini_alternatives_on_rate_limit():
22
+ print("Testing Gemini alternatives on rate limit...")
23
+
24
+ # We need to mock ChatGoogleGenerativeAI to simulate rate limit on one instance but success on another
25
+ # Since they are created inside smart_invoke, we patch the class constructor or just the instances if we can
26
+
27
+ with patch('agent.openrouter_model.invoke') as mock_openrouter, \
28
+ patch('agent.ChatGoogleGenerativeAI') as mock_gemini_class:
29
+
30
+ # OpenRouter fails
31
+ mock_openrouter.side_effect = Exception("Rate limit (429)")
32
+
33
+ # First Gemini call (primary) fails with rate limit
34
+ # Second Gemini call (alternative) succeeds
35
+ mock_primary = MagicMock()
36
+ mock_primary.invoke.side_effect = Exception("Rate limit (429)")
37
+ mock_primary.model = "gemini-2.5-flash"
38
+
39
+ mock_alt = MagicMock()
40
+ mock_alt.invoke.return_value = MagicMock(content="Gemini alternative response")
41
+ mock_alt.model = "gemini-2.5-flash-lite"
42
+
43
+ # Control the sequence of ChatGoogleGenerativeAI creation
44
+ # agent.py creates gemini_model at top level, then potentially more in smart_invoke
45
+ mock_gemini_class.side_effect = [mock_alt] # The one created in the loop
46
+
47
+ # We also need to mock the already created gemini_model
48
+ with patch('agent.gemini_model', mock_primary):
49
+ msgs = [HumanMessage(content="Hello")]
50
+ response, tier_idx = agent.smart_invoke(msgs, use_tools=False)
51
+
52
+ print(f"Response from tier {tier_idx}: {response.content}")
53
+ # Tier 1 is Gemini
54
+ assert tier_idx == 1
55
+ assert response.content == "Gemini alternative response"
56
+ print("Gemini alternative on rate limit successful!")
57
+
58
+ def test_nvidia_name():
59
+ print("Checking NVIDIA model name...")
60
+ assert agent.nvidia_model.model_name == "meta/llama-3.1-405b-instruct"
61
+ print("NVIDIA model name is correct!")
62
+
63
+ if __name__ == "__main__":
64
+ try:
65
+ test_gemini_alternatives_on_rate_limit()
66
+ test_nvidia_name()
67
+ print("All fix tests passed!")
68
+ except Exception as e:
69
+ print(f"Test failed: {e}")
70
+ import traceback
71
+ traceback.print_exc()
72
+ sys.exit(1)