# Final_Assignment_Template / verify_fixes.py
# Last commit by Paperbag: "fix nvidia" (3fc1414)
import os
import sys
from unittest.mock import MagicMock, patch
# Replace optional heavy dependencies with mocks *before* `agent` is imported
# further down, so that import does not need them installed (presumably
# `agent` imports cv2 / whisper at module level — TODO confirm).
sys.modules.update({name: MagicMock() for name in ("cv2", "whisper")})

# Dummy credentials for every provider. Plain assignment deliberately
# overwrites any real keys already present in the environment.
os.environ.update(
    {
        key: "dummy"
        for key in (
            "OPENROUTER_API_KEY",
            "GOOGLE_API_KEY",
            "GROQ_API_KEY",
            "NVIDIA_API_KEY",
            "VERCEL_API_KEY",
        )
    }
)

# Make modules in the current working directory (e.g. `agent`) importable.
sys.path.append(os.getcwd())
import agent
from langchain_core.messages import HumanMessage
def test_gemini_alternatives_on_rate_limit():
    """Check that ``agent.smart_invoke`` falls back to an alternative Gemini model.

    Setup:
      * OpenRouter (tried first) raises a rate-limit (429) error.
      * The pre-built ``agent.gemini_model`` (the primary Gemini model) also
        raises a rate-limit error.
      * Any ``ChatGoogleGenerativeAI`` constructed inside ``smart_invoke``
        returns a mock whose ``invoke`` succeeds.

    Expected: the call is answered by tier 1 (Gemini) with the alternative
    model's response, after the primary was tried.

    Raises:
        AssertionError: with a descriptive message if any expectation fails.
    """
    print("Testing Gemini alternatives on rate limit...")

    # Primary (already-created) Gemini model: always rate-limited.
    mock_primary = MagicMock()
    mock_primary.invoke.side_effect = Exception("Rate limit (429)")
    mock_primary.model = "gemini-2.5-flash"

    # Alternative Gemini model created inside smart_invoke: succeeds.
    mock_alt = MagicMock()
    mock_alt.invoke.return_value = MagicMock(content="Gemini alternative response")
    mock_alt.model = "gemini-2.5-flash-lite"

    # Patch `invoke` on the OpenRouter model's *class*, not on the instance.
    # Presumably `openrouter_model` is a pydantic-based LangChain chat model.
    # Pydantic rejects assigning attributes that are not declared fields, so
    # `patch('agent.openrouter_model.invoke')` can fail with ValueError.
    # A class-level patch affects every instance of that class while active.
    # Only tier 0 is expected to use it before tier 1 answers.
    openrouter_cls = type(agent.openrouter_model)

    with patch.object(openrouter_cls, "invoke",
                      side_effect=Exception("Rate limit (429)")), \
            patch('agent.ChatGoogleGenerativeAI') as mock_gemini_class, \
            patch('agent.gemini_model', mock_primary):
        # `return_value` (rather than a one-item `side_effect` list) keeps the
        # test from failing with StopIteration if smart_invoke constructs
        # more than one alternative model.
        mock_gemini_class.return_value = mock_alt

        msgs = [HumanMessage(content="Hello")]
        response, tier_idx = agent.smart_invoke(msgs, use_tools=False)
        print(f"Response from tier {tier_idx}: {response.content}")

        # Tier 1 is Gemini.
        assert tier_idx == 1, f"expected tier 1 (Gemini), got tier {tier_idx}"
        assert response.content == "Gemini alternative response", (
            f"unexpected response content: {response.content!r}"
        )
        # The fallback path must actually have been exercised: the primary
        # was tried first, and an alternative model was constructed.
        assert mock_primary.invoke.called, "primary Gemini model was never invoked"
        assert mock_gemini_class.called, "no alternative Gemini model was constructed"

    print("Gemini alternative on rate limit successful!")
def test_nvidia_name():
    """Check that ``agent.nvidia_model`` is configured with the expected model name.

    Raises:
        AssertionError: naming both the expected and the actual model when
            they differ. A bare ``assert`` would leave the runner's
            "Test failed: ..." line empty.
    """
    print("Checking NVIDIA model name...")
    expected = "meta/llama-3.1-405b-instruct"
    actual = agent.nvidia_model.model_name
    assert actual == expected, f"expected NVIDIA model {expected!r}, got {actual!r}"
    print("NVIDIA model name is correct!")
if __name__ == "__main__":
    # Run every verification in order. The first failure prints its message
    # and traceback, then the script exits with status 1.
    try:
        checks = (test_gemini_alternatives_on_rate_limit, test_nvidia_name)
        for check in checks:
            check()
        print("All fix tests passed!")
    except Exception as exc:
        print(f"Test failed: {exc}")
        import traceback
        traceback.print_exc()
        sys.exit(1)